# strategy.py
  1. from generation import BeamSearchStrategy
  2. class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
  3. def __init__(self, *args, banned_prefix=[], **kwargs):
  4. super().__init__(*args, **kwargs)
  5. self.banned_prefix = banned_prefix
  6. def forward(self, logits, tokens, mems):
  7. batch_size, vocab_size = logits.shape
  8. logits = logits.float()
  9. for prefix in self.banned_prefix:
  10. if self.length_generated == len(prefix) - 1:
  11. if len(prefix) == 1:
  12. logits[..., prefix[0]] = -65504
  13. else:
  14. for i in range(batch_size):
  15. if tokens[i, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
  16. logits[i, prefix[-1]] = -65504
  17. return super().forward(logits, tokens, mems)