1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- from string import punctuation
- from functools import partial
- from typing import List
- from evaluation import qa_evaluate, GenerationTask
- from .strategy import BeamSearchStrategyForLAMBADA
def exact_match_score(prediction, ground_truth):
    """Return True iff *prediction* equals *ground_truth* after stripping
    leading/trailing whitespace from both."""
    lhs = prediction.strip()
    rhs = ground_truth.strip()
    return lhs == rhs
class LAMBADA(GenerationTask):
    """LAMBADA last-word prediction task on top of the generic generation task.

    Generates beam candidates for the final word of a passage and scores them
    with exact-match accuracy.
    """

    @property
    def metrics(self):
        # Exact-match accuracy computed via the shared QA evaluation helper.
        return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}

    def __init__(self, model, tokenizer, config_path):
        super(LAMBADA, self).__init__(model, tokenizer, config_path)

        if self.config.sampling_strategy == "BeamSearchStrategy":
            # Hard-coded ids are vocabulary-specific punctuation tokens.
            banned_prefix = [[46010], [146337]]  # "'" and "``"
            invalid_slices = [20068, 146010, 146337]
            for punct in punctuation:
                token_ids = tokenizer.tokenize(punct)
                # Single-token punctuation is also forbidden mid-generation.
                if len(token_ids) == 1:
                    invalid_slices.append(token_ids[0])
                banned_prefix.append(token_ids)
            # Swap in the LAMBADA-specific beam search, reusing the end
            # tokens of the strategy installed by the base class.
            self.strategy = BeamSearchStrategyForLAMBADA(
                batch_size=self.config.micro_batch_size,
                num_beams=self.config.num_beams,
                length_penalty=self.config.length_penalty,
                consider_end=True,
                end_tokens=self.strategy.end_tokens,
                invalid_slices=invalid_slices,
                banned_prefix=banned_prefix,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                min_gen_length=self.config.min_gen_length,
                deterministic=True,
            )

    def get_first_word_tokens(self, tokens):
        """Decode *tokens*, keep only the first space-separated word, and
        return that word re-tokenized."""
        decoded = self.tokenizer.tokenizer.decode(tokens).strip()
        first_word = decoded.split(" ")[0]
        return self.tokenizer.tokenize(first_word)

    def predict_single_batch(self, batch):
        """Pick one predicted last word per sample from the returned beams.

        micro batch size = 1 here, but we still need to return a list of
        predictions for consistency.
        """
        outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
        predictions = []
        for beams in outputs_batch:
            chosen = None
            # Prefer a beam whose second "word" is a punctuation mark —
            # its first word is then a cleanly terminated prediction.
            for beam in beams:
                decoded = self.tokenizer.tokenizer.decode(beam).strip()
                pieces = decoded.split(" ")
                if len(pieces) >= 2 and pieces[1] in punctuation:
                    chosen = beam
                    break
            # Fall back to the top beam when no candidate qualifies.
            if chosen is None:
                chosen = beams[0]
            predictions.append(self.get_first_word_tokens(chosen))
        return predictions
|