Sengxian 3 years ago
parent
commit
478eb5c9d0
5 changed files with 18 additions and 20 deletions
  1. evaluation/model.py (+7 −7)
  2. generate.py (+3 −4)
  3. generation/strategies.py (+2 −5)
  4. tasks/lambada/strategy.py (+4 −3)
  5. tasks/lambada/task.py (+2 −1)

+ 7 - 7
evaluation/model.py

@@ -40,9 +40,8 @@ def batch_filling_sequence(
         # Now, we want to generate seq[counter + 1],
         # token[:, index: counter+1] needs forwarding.
         # forward
-        if num_beams > 1:
-            tokens = tokens.reshape(batch_size * num_beams, -1)
-            mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])
+        tokens = tokens.reshape(batch_size * num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
         logits, *output_per_layers = model(
             tokens[:, index:],
             position_ids[..., index: counter+1],
@@ -58,11 +57,12 @@ def batch_filling_sequence(
             logits = logits[:, -1]
         counter += 1
         index = counter
+        # if torch.distributed.get_rank() == 0:
+        #     print(f"counter: {counter}: logits: {logits.float().abs().mean()}")
         # sampling
-        if num_beams > 1:
-            logits = logits.reshape(batch_size, num_beams, -1)
-            tokens = tokens.reshape(batch_size, num_beams, -1)
-            mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
+        logits = logits.reshape(batch_size, num_beams, -1)
+        tokens = tokens.reshape(batch_size, num_beams, -1)
+        mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
         tokens, mems = strategy.forward(logits, tokens, mems)
         if len(tokens.shape) == 3 and num_beams == 1:
             num_beams = tokens.shape[1]
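
Note: the hunks above make the beam/batch reshaping unconditional instead of gated on num_beams > 1. A minimal sketch of the shape round-trip, with assumed (not repository) tensor sizes:

import torch

# Beams are folded into the batch dimension for the model forward and
# unfolded again before the strategy samples. All sizes are hypothetical.
batch_size, num_beams, seq_len, hidden, num_layers, vocab = 2, 4, 16, 8, 3, 100

tokens = torch.zeros(batch_size, num_beams, seq_len, dtype=torch.long)
mems = torch.zeros(num_layers, batch_size, num_beams, seq_len, hidden)

# Fold beams into the batch axis before calling the model.
tokens_flat = tokens.reshape(batch_size * num_beams, -1)
mems_flat = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1])

# ... the model forward would consume tokens_flat / mems_flat here ...
logits_flat = torch.randn(batch_size * num_beams, vocab)

# Unfold beams again so the strategy sees (batch, beams, ...) tensors.
logits = logits_flat.reshape(batch_size, num_beams, -1)
tokens = tokens_flat.reshape(batch_size, num_beams, -1)
mems = mems_flat.reshape(mems_flat.shape[0], batch_size, num_beams, mems_flat.shape[-2], mems_flat.shape[-1])

assert logits.shape == (batch_size, num_beams, vocab)
assert mems.shape == (num_layers, batch_size, num_beams, seq_len, hidden)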

+ 3 - 4
generate.py

@@ -114,14 +114,13 @@ def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], L
             ),
         )
         if isinstance(output, torch.Tensor):  # different strategies
-            output = list(output)
-        else:
-            output = output[0]
+            output = output.tolist()
+        output = output[0]  # batch_size = 1
         output_list.extend(output)
 
         # clip -1s and fill back generated things into seq
         for i in range(len(output_list)):
-            output = output_list[i].tolist()
+            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
             try:
                 unfinished = output.index(-1)
             except ValueError:
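
For reference, a toy walk-through of the new output handling (the token ids and shapes below are assumptions, not repository data): the tensor is converted with .tolist(), output[0] selects the single batch entry, and the later guard accepts both tensor and plain-list elements in output_list.

import torch

# Hypothetical generation result: batch_size = 1, 2 candidate sequences,
# with -1 marking positions that were never filled.
output = torch.tensor([[[31, 41, 59, -1, -1],
                        [31, 41, 27, 18, -1]]])
if isinstance(output, torch.Tensor):  # different strategies
    output = output.tolist()
output = output[0]                    # batch_size = 1 -> list of candidate sequences

output_list = []
output_list.extend(output)            # one plain list per candidate

for i in range(len(output_list)):
    item = output_list[i]
    seq = item.tolist() if isinstance(item, torch.Tensor) else item
    unfinished = seq.index(-1) if -1 in seq else len(seq)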

+ 2 - 5
generation/strategies.py

@@ -21,6 +21,7 @@ class BaseStrategy:
         return self._is_done.all()
 
     def forward(self, logits, tokens, mems, temperature=None):
+        logits = logits.view(-1, logits.size(-1))
         batch_size = tokens.shape[0]
         if temperature is None:
             temperature = self.temperature
@@ -38,7 +39,7 @@ class BaseStrategy:
                 pred[i] = -1
             elif pred[i].item() in self.end_tokens:
                 self._is_done[i] = True
-        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
+        tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
         return tokens, mems
 
     def finalize(self, tokens, mems):
@@ -94,10 +95,6 @@ class BeamSearchStrategy:
         return self._is_done.all()
 
     def forward(self, logits, tokens, mems):
-        if len(logits.shape) == 2:
-            logits = logits.unsqueeze(1)
-            tokens = tokens.unsqueeze(1)
-            mems = mems.unsqueeze(2)
         batch_size, num_beams, vocab_size = logits.shape
         seq_len = tokens.shape[-1]
         logits = logits.float()
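
A toy sketch of what these two hunks imply for BaseStrategy.forward (shapes assumed, greedy argmax used as a stand-in for the actual sampling): logits arrive as (batch, beams, vocab) and are flattened before sampling, and the appended token keeps the leading dimensions of tokens.

import torch

batch_size, num_beams, vocab_size, seq_len = 2, 1, 10, 5
logits = torch.randn(batch_size, num_beams, vocab_size)
tokens = torch.zeros(batch_size, num_beams, seq_len, dtype=torch.long)

flat = logits.view(-1, logits.size(-1))       # (batch*beams, vocab)
pred = flat.argmax(dim=-1)                    # stand-in for the real sampling step
tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
assert tokens.shape == (batch_size, num_beams, seq_len + 1)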

+ 4 - 3
tasks/lambada/strategy.py

@@ -7,7 +7,7 @@ class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
         self.banned_prefix = banned_prefix
 
     def forward(self, logits, tokens, mems):
-        batch_size, vocab_size = logits.shape
+        batch_size, num_beams, vocab_size = logits.shape
         logits = logits.float()
         for prefix in self.banned_prefix:
             if self.length_generated == len(prefix) - 1:
@@ -15,6 +15,7 @@ class BeamSearchStrategyForLAMBADA(BeamSearchStrategy):
                     logits[..., prefix[0]] = -65504
                 else:
                     for i in range(batch_size):
-                        if tokens[i, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
-                            logits[i, prefix[-1]] = -65504
+                        for j in range(num_beams):
+                            if tokens[i, j, -(len(prefix) - 1) :].tolist() == prefix[:-1]:
+                                logits[i, j, prefix[-1]] = -65504
         return super().forward(logits, tokens, mems)
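
A small self-contained illustration of the per-beam prefix ban (token ids and shapes are hypothetical, not from the repo): when the most recent tokens of a beam match a banned prefix, the logit of the prefix's final token is pushed to -65504, the most negative finite fp16 value.

import torch

banned_prefix = [[7, 8, 9]]                 # hypothetical token-id prefix
batch_size, num_beams, vocab_size = 1, 2, 16
logits = torch.zeros(batch_size, num_beams, vocab_size)
tokens = torch.tensor([[[1, 7, 8],          # beam 0 matches the banned prefix
                        [1, 2, 3]]])        # beam 1 does not

for prefix in banned_prefix:
    for i in range(batch_size):
        for j in range(num_beams):
            if tokens[i, j, -(len(prefix) - 1):].tolist() == prefix[:-1]:
                logits[i, j, prefix[-1]] = -65504

assert logits[0, 0, 9].item() == -65504
assert logits[0, 1, 9].item() == 0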

+ 2 - 1
tasks/lambada/task.py

@@ -28,7 +28,8 @@ class LAMBADA(GenerationTask):
                     invalid_slices.append(pp[0])
                 banned_prefix.append(pp)
             self.strategy = BeamSearchStrategyForLAMBADA(
-                self.config.num_beams,
+                batch_size=self.config.micro_batch_size,
+                num_beams=self.config.num_beams,
                 length_penalty=self.config.length_penalty,
                 consider_end=True,
                 end_tokens=self.strategy.end_tokens,