task.py

from string import punctuation
from functools import partial
from typing import List

from evaluation import qa_evaluate, GenerationTask

from .strategy import BeamSearchStrategyForLAMBADA


def exact_match_score(prediction, ground_truth):
    return prediction.strip() == ground_truth.strip()


class LAMBADA(GenerationTask):
    @property
    def metrics(self):
        return {"Accuracy": partial(qa_evaluate, metric=exact_match_score)}

    def __init__(self, model, tokenizer, config_path):
        super(LAMBADA, self).__init__(model, tokenizer, config_path)

        if self.config.sampling_strategy == "BeamSearchStrategy":
            banned_prefix = [[46010], [146337]]  # "'" and "``"
            invalid_slices = [20068, 146010, 146337]
            # Forbid punctuation both as a generated token and as a beam
            # prefix, steering the beams toward producing a bare word.
            for p in punctuation:
                pp = tokenizer.tokenize(p)
                if len(pp) == 1:
                    invalid_slices.append(pp[0])
                banned_prefix.append(pp)
            self.strategy = BeamSearchStrategyForLAMBADA(
                batch_size=self.config.micro_batch_size,
                num_beams=self.config.num_beams,
                length_penalty=self.config.length_penalty,
                consider_end=True,
                end_tokens=self.strategy.end_tokens,  # reuse the end tokens of the default strategy set by the base class
                invalid_slices=invalid_slices,
                banned_prefix=banned_prefix,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                min_gen_length=self.config.min_gen_length,
                deterministic=True,
            )

    def get_first_word_tokens(self, tokens):
        # Decode the beam, keep only the first whitespace-separated word,
        # and re-tokenize it so the prediction is a clean token sequence.
        text = self.tokenizer.tokenizer.decode(tokens).strip()
        return self.tokenizer.tokenize(text.split(" ")[0])

    def predict_single_batch(self, batch):
        # micro batch size = 1 here, but we still need to return a list of predictions for consistency
        outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
        predictions = []
        for outputs in outputs_batch:
            # Prefer the highest-ranked beam whose second whitespace-separated
            # piece is punctuation, i.e. a beam that generated exactly one word
            # before stopping.
            found = False
            for output in outputs:
                text = self.tokenizer.tokenizer.decode(output).strip()
                spl = text.split(" ")
                if len(spl) >= 2 and spl[1] in punctuation:
                    predictions.append(self.get_first_word_tokens(output))
                    found = True
                    break
            if not found:
                # No beam matched; fall back to the first word of the top beam.
                predictions.append(self.get_first_word_tokens(outputs[0]))
        return predictions
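
The beam-selection heuristic in predict_single_batch can be exercised in isolation. Below is a minimal, self-contained sketch with a stub tokenizer; DummyTokenizer and pick_final_word are hypothetical helpers introduced purely for illustration and are not part of the repository:

# Minimal sketch of the beam-selection heuristic used in predict_single_batch.
# DummyTokenizer and pick_final_word are hypothetical stand-ins for the real
# model tokenizer and the in-class logic; token ids here are just character codes.
from string import punctuation


class DummyTokenizer:
    def tokenize(self, text):
        return [ord(c) for c in text]  # fake one-id-per-character "tokens"

    def decode(self, tokens):
        return "".join(chr(t) for t in tokens)


def pick_final_word(beams, tokenizer):
    # Prefer the highest-ranked beam whose second whitespace-separated piece
    # is punctuation (the beam produced exactly one word, then stopped);
    # otherwise fall back to the first word of the top beam.
    for beam in beams:
        parts = tokenizer.decode(beam).strip().split(" ")
        if len(parts) >= 2 and parts[1] in punctuation:
            return tokenizer.tokenize(parts[0])
    return tokenizer.tokenize(tokenizer.decode(beams[0]).strip().split(" ")[0])


tok = DummyTokenizer()
beams = [
    tok.tokenize("window sill again"),  # ran past the target word
    tok.tokenize("window ."),           # one word, then punctuation
]
print(tok.decode(pick_final_word(beams, tok)))  # -> window

Note that the check spl[1] in punctuation appears to rely on the tokenizer inserting a space before punctuation when decoding, so a beam that stops right after the target word decodes to something like "window .", while a run-on beam decodes to several words.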