model.py

import torch
from typing import List, Union

from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
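
# filling_sequence autoregressively fills in the positions of a sequence that
# are marked -1, under the control of the strategy passed to it; see
# generate_text below, which is how this file drives it.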


class ModelForEvaluation(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def process_data(batch):
        # Move the batch tensors onto the current GPU; the attention mask is
        # cast to bool and gains a head dimension via unsqueeze(1).
        return (
            batch["tokens"].to(device=torch.cuda.current_device()).long(),
            batch["position_ids"].to(device=torch.cuda.current_device()).long(),
            batch["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1),
        )
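
    # Assumed batch layout for process_data; the shapes are inferred from the
    # code (the [b, sq, vocab] logits below imply [b, sq] tokens) and are not
    # documented in the original:
    #   batch["tokens"]:         LongTensor [b, sq]   prompt + target token ids
    #   batch["position_ids"]:   LongTensor [b, sq]
    #   batch["attention_mask"]: Tensor [b, sq, sq];  unsqueeze(1) adds the head dim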

    def cond_log_prob(self, batch) -> List[List[float]]:
        """
        @return: Conditional log probability of each option
        """
        tokens, position_ids, attention_mask = self.process_data(batch)
        choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
        is_single_token = batch["is_single_token"]

        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            # Normalize logits into log probabilities: [b, sq, vocab]
            logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)

        log_probs = []
        if is_single_token:  # single-token options: one logit lookup per choice
            for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs.append(logits[choice_target_ids[0], choices].tolist())
        else:  # multi-token options: sum the per-token log probabilities
            for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs_single = []
                for choice, choice_target_id in zip(choices, choice_target_ids):
                    tmp = output[choice_target_id, choice]
                    log_probs_single.append(tmp.sum().tolist())
                log_probs.append(log_probs_single)
        return log_probs
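
    # Sketch of how cond_log_prob is typically consumed (hypothetical; not part
    # of the original file). For a single-token task, batch["choices"] holds the
    # candidate answer token ids and batch["choice_target_ids"] the position
    # whose logits predict them, so picking the answer reduces to an argmax:
    #
    #   scores = model.cond_log_prob(batch)   # e.g. [[-0.7, -1.2]] for one example
    #   prediction = max(range(len(scores[0])), key=scores[0].__getitem__)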

    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
        """
        @return: A list of generated token ids, sorted by score in descending order
        """
        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
        # Mark everything after the context as -1 so filling_sequence generates it.
        seq[context_length:] = -1

        def get_masks_and_position_ids(seq):
            tokens = seq.unsqueeze(0)
            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
            return tokens, attention_mask, position_ids

        self.model.eval()
        with torch.no_grad():
            output = filling_sequence(
                self.model,
                seq,
                get_masks_and_position_ids=get_masks_and_position_ids,
                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                strategy=strategy,
            )[0]

        if isinstance(output, torch.Tensor):  # different strategies return a tensor or a list
            output = list(output)

        output_targets = []
        for line in output:
            line = line.tolist()
            # Everything before the first remaining -1 has been generated.
            unfinished = line.index(-1) if -1 in line else len(line)
            if line[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            # Drop the prompt and any trailing end token.
            line = line[context_length:unfinished]
            output_targets.append(line)
        return output_targets if return_all_beams else output_targets[0]
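

# Usage sketch (hypothetical; not part of the original file).
# `load_model_and_sample` is a made-up helper standing in for the project's
# checkpoint loading and data pipeline. BeamSearchStrategy ships with
# SwissArmyTransformer, though its exact constructor arguments depend on the
# installed version:
#
#   from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy
#
#   raw_model, sample = load_model_and_sample()          # hypothetical helper
#   model = ModelForEvaluation(raw_model)
#   strategy = BeamSearchStrategy(num_beams=4, length_penalty=1.0, end_tokens=[eop_id])
#   best = model.generate_text(sample, strategy)                       # List[int]
#   beams = model.generate_text(sample, strategy, return_all_beams=True)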