@@ -2,6 +2,7 @@ import torch
 import time
 import numpy as np
 import torch.distributed as dist
+from tqdm import tqdm

 from typing import Dict, Callable, Type, Tuple, List, Any
 from abc import ABC, abstractmethod
@@ -42,6 +43,11 @@ class BaseTask(ABC):

         self.file_groups = self.get_file_groups()
         self.verbose = dist.get_rank() == 0
+        self.save_prediction = config.save_prediction
+
+    def save_prediction_to_file(self, file, prediction, data):
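+        # Subclass hook: persist this file's predictions to disk; the base class leaves it a no-op.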
+        pass

     def get_file_groups(self):
         pattern_group = {}
@@ -71,7 +77,7 @@ class BaseTask(ABC):

             result_dict_group = {}
             for file in filelist:
-                dataset = self.build_dataset(file)
+                dataset = self.build_dataset(file, group_name)
                 dataloader = build_data_loader(
                     dataset,
                     micro_batch_size=self.config.micro_batch_size,
@@ -81,13 +87,19 @@ class BaseTask(ABC):
                 )

                 prediction = []
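+                # Only rank 0 renders a progress bar; other ranks fall back to an identity wrapper.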
+                tqdm_wrapper = tqdm if dist.get_rank() == 0 else (lambda x: x)
                 with torch.no_grad():
-                    for _, batch in enumerate(dataloader):
-                        prediction.append(self.predict_single_batch(batch))
+                    for _, batch in tqdm_wrapper(enumerate(dataloader)):
+                        p_batch = self.predict_single_batch(batch)
+                        prediction.append(p_batch)

                 prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                 result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                 result_dict_group[file] = (result_dict, len(dataset))
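+                # Write predictions once, from rank 0 only, when save_prediction is enabled.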
+                if dist.get_rank() == 0 and self.save_prediction:
+                    self.save_prediction_to_file(file, prediction, dataset.data)

                 if self.verbose:
                     self.report_single_metrics(file, result_dict)
@@ -152,7 +164,7 @@ class BaseTask(ABC):
         pass

     @abstractmethod
-    def build_dataset(self, relative_path: str) -> EvaluationDataset:
+    def build_dataset(self, relative_path: str, split: str) -> EvaluationDataset:
         pass


@@ -163,13 +175,18 @@ class GenerationTask(BaseTask, ABC):
     def config_class(cls):
         return GenerationTaskConfig

-    def build_dataset(self, relative_path):
+    def build_dataset(self, relative_path, split):
         return GenerationTaskDataset(join(self.config.path, relative_path), self.config)

     def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
         super(GenerationTask, self).__init__(model, tokenizer, config)

         end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
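+        # Append any task-configured end tokens, keeping the last token id of each tokenized string.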
+        if self.config.end_tokens:
+            for token in self.config.end_tokens:
+                end_tokens.append(self.tokenizer.tokenize(token)[-1])
+        print_rank_0(f"End tokens {end_tokens}")
         if self.config.sampling_strategy == "BaseStrategy":
             self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
         elif self.config.sampling_strategy == "BeamSearchStrategy":
@@ -180,7 +197,7 @@ class GenerationTask(BaseTask, ABC):
                 end_tokens=end_tokens,
                 no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                 min_gen_length=self.config.min_gen_length,
-                deterministic=True,  # For evaluation, we need a determined generation strategy
+                deterministic=False,  # allow non-deterministic beam search during evaluation
             )
         else:
             raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
@@ -199,7 +216,7 @@ class MultiChoiceTask(BaseTask, ABC):
     def config_class(cls):
         return MultiChoiceTaskConfig

-    def build_dataset(self, relative_path):
+    def build_dataset(self, relative_path, split):
         return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)

     def predict_single_batch(self, batch) -> List[int]: