from os.path import join from typing import Dict, Tuple, List from abc import ABC from collections import defaultdict from evaluation import ( MultiChoiceTask, MultiChoiceTaskConfig, ) from evaluation.dataset import ( MultiChoiceTaskDataset, ) from evaluation.utils import ( print_rank_0, get_tokenized_input, ) class CrowsPairTask(MultiChoiceTask, ABC): config: MultiChoiceTaskConfig def build_dataset(self, relative_path): return CrowsPairDataset(join(self.config.path, relative_path), self.config) def predict_single_batch(self, batch) -> List[int]: log_probs = self.model.cond_log_prob(batch) return log_probs def CrowsPairMetric(self, predictions, examples): print_rank_0("Special metric for CrowsPair") results = defaultdict(float) labels = defaultdict() for prediction, example in zip(predictions, examples): prediction = prediction[0] if example["sent_ID"] == 1: results[example["pair_ID"]] = results[example["pair_ID"]] + prediction else: results[example["pair_ID"]] = results[example["pair_ID"]] - prediction labels[example["pair_ID"]] = example["bias_type"] cat_postivie = defaultdict(int) cat_tt = defaultdict(int) final = defaultdict(int) for val1, val2 in zip(results.values(), labels.values()): if val1 >= 0: cat_postivie[val2] = cat_postivie[val2] + 1 else: cat_postivie[val2] = cat_postivie[val2] cat_tt[val2] = cat_tt[val2] + 1 for key, val in cat_postivie.items(): final[key] = val / cat_tt[key] return final def report_single_metrics(self, file: str, result_dict: Dict[str, float]): pass @property def metrics(self): return {"CP": self.CrowsPairMetric} def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1): for result in result_dict_group.values(): result = result[0] for value1 in result.items(): value1 = value1[1] for key, value in value1.items(): print_rank_0("category:{cat} score:{score}".format(cat=key, score=round(value * 100,2))) class CrowsPairDataset(MultiChoiceTaskDataset): config: MultiChoiceTaskConfig def __init__(self, path, config: MultiChoiceTaskConfig): self.is_single_token = True # set to False later in process_single_item func self.eval_data = [] super().__init__(path, config) def process_single_item(self, item): text, choices, label = ( get_tokenized_input(item, "inputs"), get_tokenized_input(item, "choices"), item["label"], ) pair_ID, sent_ID, bias_type = ( item["pair_ID"], item["sent_ID"], item["bias_type"], ) tgt_seq_length = sum([len(choice) for choice in choices]) if tgt_seq_length == len(choices): # For single token, we only insert one [sop] tgt_seq_length = 1 assert tgt_seq_length < self.config.max_seq_length if len(text) + tgt_seq_length + 2 > self.config.max_seq_length: text_length = self.config.max_seq_length - tgt_seq_length - 2 text = text[len(text) - text_length : len(text)] assert not ( self.mask_id in text and self.config.use_multitask_encoding ), "Unified multitask encoding don't support blank filling" if tgt_seq_length != 1: self.is_single_token = False dataset = { "text": text, "choices": choices, "label": label, "pair_ID": pair_ID, "sent_ID": sent_ID, "bias_type": bias_type, } return dataset