task.py

from string import punctuation
from functools import partial
from typing import List

from evaluation import qa_evaluate, GenerationTask

from .strategy import BeamSearchStrategyForLAMBADA


def exact_match_score(prediction, ground_truth):
    return prediction.strip() == ground_truth.strip()


class LAMBADA(GenerationTask):
    @property
    def metrics(self):
        # LAMBADA is scored by exact-match accuracy on the predicted last word.
        return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}

    def __init__(self, model, tokenizer, config_path):
        super(LAMBADA, self).__init__(model, tokenizer, config_path)

        if self.config.sampling_strategy == "BeamSearchStrategy":
            # Tokenizer-specific ids: ban punctuation-only prefixes and slices so
            # beams are steered toward generating an actual word continuation.
            banned_prefix = [[46010], [146337]]  # "'" and "``"
            invalid_slices = [20068, 146010, 146337]
            for p in punctuation:
                pp = tokenizer.tokenize(p)
                if len(pp) == 1:
                    invalid_slices.append(pp[0])
                banned_prefix.append(pp)
            self.strategy = BeamSearchStrategyForLAMBADA(
                batch_size=self.config.micro_batch_size,
                num_beams=self.config.num_beams,
                length_penalty=self.config.length_penalty,
                consider_end=True,
                end_tokens=self.strategy.end_tokens,  # reuse end tokens from the default strategy built by the parent class
                invalid_slices=invalid_slices,
                banned_prefix=banned_prefix,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                min_gen_length=self.config.min_gen_length,
                deterministic=True,
            )

    def get_first_word_tokens(self, tokens):
        # Decode the generated tokens and re-tokenize only the first
        # whitespace-delimited word.
        text = self.tokenizer.tokenizer.decode(tokens).strip()
        return self.tokenizer.tokenize(text.split(" ")[0])

    def predict_single_batch(self, batch):
        outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
        predictions = []
        for outputs in outputs_batch:
            # Prefer the highest-ranked beam whose decoded text is a single word
            # followed by punctuation; otherwise fall back to the top beam.
            found = False
            for output in outputs:
                text = self.tokenizer.tokenizer.decode(output).strip()
                spl = text.split(" ")
                if len(spl) >= 2 and spl[1] in punctuation:
                    predictions.append(self.get_first_word_tokens(output))
                    found = True
                    break
            if not found:
                predictions.append(self.get_first_word_tokens(outputs[0]))
        return predictions
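
For reference, a minimal standalone sketch of the beam-selection rule used in predict_single_batch: among the returned beams, take the highest-ranked one whose decoded text is a single word followed by a punctuation mark, otherwise fall back to the top beam. The pick_beam helper and the decoded strings below are illustrative only and are not part of task.py.

from string import punctuation

def pick_beam(decoded_beams):
    # Return the first (highest-ranked) beam whose second whitespace-delimited
    # token is punctuation, i.e. a single word followed by a punctuation mark.
    for text in decoded_beams:
        spl = text.strip().split(" ")
        if len(spl) >= 2 and spl[1] in punctuation:
            return text
    # No beam matched: fall back to the top-ranked beam.
    return decoded_beams[0]

beams = ["in the morning", "night .", "the end"]
print(pick_beam(beams))  # -> "night ."

The real code then re-tokenizes only the first word of the chosen beam (get_first_word_tokens), since LAMBADA evaluation compares just the predicted final word against the ground truth.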