bigbench.py

import torch
import scipy.special

import bigbench.models.model_utils as model_utils
from bigbench.api.model import Model, ModelData

from .model import ModelForEvaluation
from .dataset import MultiChoiceTaskDataset, GenerationTaskDataset
from .configs import MultiChoiceTaskConfig, GenerationTaskConfig, TaskType
from .utils import build_data_loader, gather_result
from generation import BaseStrategy


class ModelForBigBench(Model):
    model: ModelForEvaluation

    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = ModelForEvaluation(model)
        self.max_seq_length = 2048  # maximum total sequence length; longer prompts are truncated below

    def generate_text(self, inputs, max_length, micro_batch_size=8, stop_string=None, output_regex=None):
        # Accept either a single prompt or a list of prompts, as the BIG-bench API allows.
        squeeze = False
        if isinstance(inputs, str):
            squeeze = True
            inputs = [inputs]
        config = GenerationTaskConfig(
            name="big-bench",
            type=TaskType.GENERATION,
            max_gen_length=max_length or 128,
            path="",
            micro_batch_size=micro_batch_size,
        )
        dataset = GenerationTaskDataset(path=[], config=config)
        for text in inputs:
            text = dataset.tokenizer.tokenize(text)
            # Truncate the prompt from the left so that prompt + generation + special tokens fit.
            if len(text) + config.max_gen_length + 2 > self.max_seq_length:
                text_length = self.max_seq_length - config.max_gen_length - 2
                text = text[len(text) - text_length :]
            dataset.data.append({"text": text})
        end_tokens = [dataset.tokenizer.get_command("eop"), dataset.tokenizer.get_command("eos")]
        # Greedy decoding (top_k=1), stopping at the end-of-passage / end-of-sequence tokens.
        strategy = BaseStrategy(batch_size=config.micro_batch_size, temperature=1.0, top_k=1, end_tokens=end_tokens)
        dataloader = build_data_loader(
            dataset,
            micro_batch_size=config.micro_batch_size,
            num_workers=1,
            drop_last=False,
            collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
        )
        generations = []
        with torch.no_grad():
            for _, batch in enumerate(dataloader):
                generations.append(self.model.generate_text(batch, strategy))
        generations = gather_result(generations, len(dataset), config.micro_batch_size)
        generations = [dataset.tokenizer.detokenize(generation) for generation in generations]
        generations = model_utils.postprocess_output(generations, max_length, stop_string, output_regex)
        return generations if not squeeze else generations[0]

    def cond_log_prob(self, inputs, targets, absolute_normalization=False):
        assert not isinstance(
            targets, str
        ), "targets in cond_log_prob must be a list (or a list of lists if inputs is a list). targets was instead a str."
        # Accept either a single input with its list of choices, or a list of inputs,
        # each with its own list of choices.
        squeeze = False
        if isinstance(inputs, str):
            squeeze = True
            inputs, targets = [inputs], [targets]
        config = MultiChoiceTaskConfig(name="big-bench", type=TaskType.MULTICHOICE, path="", micro_batch_size=1)
        dataset = MultiChoiceTaskDataset(path=[], config=config)
        max_sample_length = 0
        for text, choices in zip(inputs, targets):
            text = dataset.tokenizer.tokenize(text)
            choices = [dataset.tokenizer.tokenize(choice) for choice in choices]
            tgt_seq_length = sum([len(choice) for choice in choices])
            if tgt_seq_length == len(choices):
                # For single-token choices, we only insert one [sop]
                tgt_seq_length = 1
            if tgt_seq_length > 1:
                dataset.is_single_token = False
            assert tgt_seq_length < self.max_seq_length
            # Truncate the prompt from the left so that prompt + choices + special tokens fit.
            if len(text) + tgt_seq_length + 2 > self.max_seq_length:
                text_length = self.max_seq_length - tgt_seq_length - 2
                text = text[len(text) - text_length :]
            max_sample_length = max(max_sample_length, len(text) + tgt_seq_length + 2)
            dataset.data.append({"text": text, "choices": choices})
        # Choose a micro batch size inversely proportional to the longest sample.
        config.micro_batch_size = max(1, 4 * self.max_seq_length // max_sample_length)
        dataloader = build_data_loader(
            dataset,
            micro_batch_size=config.micro_batch_size,
            num_workers=1,
            drop_last=False,
            collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
        )
        scores = []
        with torch.no_grad():
            for _, batch in enumerate(dataloader):
                scores.append(self.model.cond_log_prob(batch))
        scores = gather_result(scores, len(dataset), config.micro_batch_size)
        if absolute_normalization:
            log_probs = scores
        else:
            # Normalize each sample's choice scores so that its choice probabilities sum to 1.
            log_probs = [[lp - scipy.special.logsumexp(score) for lp in score] for score in scores]
        return log_probs if not squeeze else log_probs[0]

    def model_data(self):
        return ModelData(
            model_family="GLM",
            model_name="GLM-130B",
            total_params=130000000000,
            non_embedding_params=130000000000,
            flop_matched_non_embedding_params=130000000000,
            training_batch_size=4224,
            training_steps=49300,
            description="An Open Bilingual Pre-Trained Model",
            decoding_params={},
        )
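

# Usage sketch (hypothetical): how a BIG-bench task would query this adapter once a
# GLM model has been loaded elsewhere; `glm_model` stands in for whatever object the
# surrounding evaluation code constructs and passes in.
#
#   model = ModelForBigBench(glm_model)
#
#   # Free-form generation, as a generative BIG-bench task would request it.
#   completion = model.generate_text("Q: What is 2 + 2?\nA:", max_length=16)
#
#   # Per-choice log probabilities, as a multiple-choice BIG-bench task would request
#   # them; by default they are normalized over the given choices.
#   log_probs = model.cond_log_prob(inputs="The sky is", targets=[" blue", " green"])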