# task.py
from functools import partial
from string import punctuation
from typing import List

from evaluation import qa_evaluate, GenerationTask

from .strategy import BeamSearchStrategyForLAMBADA
  6. def exact_match_score(prediction, ground_truth):
  7. return prediction.strip() == ground_truth.strip()
  8. class LAMBADA(GenerationTask):
  9. @property
  10. def metrics(self):
  11. return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}
  12. def __init__(self, model, tokenizer, config_path):
  13. super(LAMBADA, self).__init__(model, tokenizer, config_path)
  14. if self.config.sampling_strategy == "BeamSearchStrategy":
  15. banned_prefix = [[46010], [146337]] # "'" and "``"
  16. invalid_slices = [20068, 146010, 146337]
  17. for p in punctuation:
  18. pp = tokenizer.tokenize(p)
  19. if len(pp) == 1:
  20. invalid_slices.append(pp[0])
  21. banned_prefix.append(pp)
  22. self.strategy = BeamSearchStrategyForLAMBADA(
  23. self.config.num_beams,
  24. length_penalty=self.config.length_penalty,
  25. consider_end=True,
  26. end_tokens=self.strategy.end_tokens,
  27. invalid_slices=invalid_slices,
  28. banned_prefix=banned_prefix,
  29. no_repeat_ngram_size=self.config.no_repeat_ngram_size,
  30. min_gen_length=self.config.min_gen_length,
  31. deterministic=True,
  32. )
  33. def get_first_word_tokens(self, tokens):
  34. text = self.tokenizer.tokenizer.decode(tokens).strip()
  35. return self.tokenizer.tokenize(text.split(" ")[0])
  36. def predict_single_batch(self, batch):
  37. # micro batch size = 1 here, but we still need to return a list of predictions for consistency
  38. outputs: List[List[int]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
  39. for output in outputs:
  40. text = self.tokenizer.tokenizer.decode(output).strip()
  41. spl = text.split(" ")
  42. if len(spl) >= 2 and spl[1] in punctuation:
  43. return [self.get_first_word_tokens(output)]
  44. return [self.get_first_word_tokens(outputs[0])]