# model.py

import torch
from typing import List, Union

from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy


class ModelForEvaluation(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def process_data(batch):
        # Move the shared evaluation tensors to the current GPU: token ids and position
        # ids as int64, the attention mask as bool with an extra broadcast dimension.
        return (
            batch["tokens"].to(device=torch.cuda.current_device()).long(),
            batch["position_ids"].to(device=torch.cuda.current_device()).long(),
            batch["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1),
        )
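    # Expected batch layout (inferred from the methods below): "tokens", "position_ids",
    # and "attention_mask" are shared by all tasks; cond_log_prob additionally uses
    # "choices", "choice_target_ids", and "is_single_token"; generate_text uses
    # "context_length"; calculate_loss uses "targets" and "loss_masks".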

    def cond_log_prob(self, batch) -> List[List[float]]:
        """
        @return: Conditional log probability of each option
        """
        tokens, position_ids, attention_mask = self.process_data(batch)
        choices_batch, choice_target_ids_batch = batch["choices"], batch["choice_target_ids"]
        is_single_token = batch["is_single_token"]

        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            logits_batch = torch.nn.functional.log_softmax(logits, dim=-1)

        # output: [b, sq, vocab]
        log_probs = []
        if is_single_token:  # Single token: read each option's log prob at the one target position.
            for logits, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs.append(logits[choice_target_ids[0], choices].tolist())
        else:  # Multi token: sum the per-token log probs over each option's target span.
            for output, choices, choice_target_ids in zip(logits_batch, choices_batch, choice_target_ids_batch):
                log_probs_single = []
                for choice, choice_target_id in zip(choices, choice_target_ids):
                    tmp = output[choice_target_id, choice]
                    log_probs_single.append(tmp.sum().tolist())
                log_probs.append(log_probs_single)
        return log_probs

    def generate_text(self, sample, strategy, return_all_beams=False) -> Union[List[int], List[List[int]]]:
        """
        @return: Generated token ids after the context: the best beam by default, or one
                 list per beam (sorted by score in descending order) if return_all_beams.
        """
        seq = torch.squeeze(sample["tokens"].to(device=torch.cuda.current_device()).long())
        context_length = sample["context_length"].to(device=torch.cuda.current_device()).long()
        seq[context_length:] = -1  # mark the positions to be filled by generation

        def get_masks_and_position_ids(seq):
            tokens = seq.unsqueeze(0)
            attention_mask = sample["attention_mask"].to(device=torch.cuda.current_device()).bool().unsqueeze(1)
            position_ids = sample["position_ids"].to(device=torch.cuda.current_device()).long()
            return tokens, attention_mask, position_ids

        self.model.eval()
        with torch.no_grad():
            output = filling_sequence(
                self.model,
                seq,
                get_masks_and_position_ids=get_masks_and_position_ids,
                batch_size=strategy.num_beams if hasattr(strategy, "num_beams") else 1,
                strategy=strategy,
            )[0]

        if isinstance(output, torch.Tensor):  # different strategies
            output = list(output)

        output_targets = []
        for line in output:
            line = line.tolist()
            # Trim at the first unfilled position (-1) and drop a trailing end token.
            unfinished = line.index(-1) if -1 in line else len(line)
            if line[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            line = line[context_length:unfinished]
            output_targets.append(line)
        return output_targets if return_all_beams else output_targets[0]
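    # Note on calculate_loss below: it temporarily forces parallel_output=True so that,
    # under model parallelism, vocab_parallel_cross_entropy receives logits that are still
    # split along the vocabulary dimension, which is the input layout it expects.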

    def calculate_loss(self, batch) -> List[float]:
        tokens, position_ids, attention_mask = self.process_data(batch)
        targets, loss_masks = (
            batch["targets"].to(device=torch.cuda.current_device()).long(),
            batch["loss_masks"].to(device=torch.cuda.current_device()).long(),
        )

        original_parallel_output = self.model.transformer.parallel_output
        self.model.transformer.parallel_output = True
        self.model.eval()
        with torch.no_grad():
            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
            losses = vocab_parallel_cross_entropy(logits.contiguous().float(), targets)
            # Sum the per-token losses over the unmasked target positions of each sample.
            loss = torch.sum(losses * loss_masks, dim=-1)
        self.model.transformer.parallel_output = original_parallel_output

        # return list(zip(loss.tolist(), loss_masks.sum(dim=-1).tolist()))
        return loss.tolist()
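

# Illustrative usage sketch (not part of the original file; names such as `base_model`,
# `strategy`, `batch`, and `sample` are assumed to come from the surrounding evaluation
# harness):
#
#     evaluator = ModelForEvaluation(base_model)
#     option_log_probs = evaluator.cond_log_prob(batch)          # List[List[float]]
#     generated_ids = evaluator.generate_text(sample, strategy)  # token ids after the context
#     token_losses = evaluator.calculate_loss(batch)             # masked sum of per-token losses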