
Fix type bugs

Sengxian 3 years ago
parent commit e131c8b557
2 changed files with 17 additions and 10 deletions
  1. +3 -3   evaluation/model.py
  2. +14 -7  tasks/lambada/task.py

+ 3 - 3
evaluation/model.py

@@ -118,7 +118,7 @@ class ModelForEvaluation(torch.nn.Module):
         return log_probs
 
     def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
-        List[int], List[List[int]]]:
+        List[List[int]], List[List[List[int]]]]:
         """
        @return: A list of texts generated by the model, sorted by score in descending order
         """
@@ -155,16 +155,16 @@ class ModelForEvaluation(torch.nn.Module):
             )[0]
 
         if isinstance(output, torch.Tensor):  # different strategies
-            output = list(output)
+            output = output.tolist()
 
         output_targets = []
         context_length = seqs.shape[1]
         for lines in output:
+            lines = lines.tolist() if isinstance(lines, torch.Tensor) else lines
             output_target = []
             if not isinstance(lines, list):
                 lines = [lines]
             for line in lines:
-                line = line.tolist()
                 unfinished = line.index(-1) if -1 in line else len(line)
                 if line[unfinished - 1] in strategy.end_tokens:
                     unfinished -= 1
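
One type wrinkle behind the second hunk (my reading of the diff; tensor values invented, assumes only torch): list() over a 2-D tensor yields row tensors, which have no .index() method, while .tolist() yields the plain nested lists the downstream loop expects.

    import torch

    output = torch.tensor([[5, 7, -1], [2, 3, 4]])

    rows = list(output)        # [tensor([5, 7, -1]), tensor([2, 3, 4])]
    nested = output.tolist()   # [[5, 7, -1], [2, 3, 4]]

    # Only the plain-list form supports the membership/index logic below:
    line = nested[0]
    unfinished = line.index(-1) if -1 in line else len(line)  # -> 2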

+ 14 - 7
tasks/lambada/task.py

@@ -46,10 +46,17 @@ class LAMBADA(GenerationTask):
 
     def predict_single_batch(self, batch):
         # micro batch size = 1 here, but we still need to return a list of predictions for consistency
-        outputs: List[List[int]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
-        for output in outputs:
-            text = self.tokenizer.tokenizer.decode(output).strip()
-            spl = text.split(" ")
-            if len(spl) >= 2 and spl[1] in punctuation:
-                return [self.get_first_word_tokens(output)]
-        return [self.get_first_word_tokens(outputs[0])]
+        outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
+        predictions = []
+        for outputs in outputs_batch:
+            found = False
+            for output in outputs:
+                text = self.tokenizer.tokenizer.decode(output).strip()
+                spl = text.split(" ")
+                if len(spl) >= 2 and spl[1] in punctuation:
+                    predictions.append(self.get_first_word_tokens(output))
+                    found = True
+                    break
+            if not found:
+                predictions.append(self.get_first_word_tokens(outputs[0]))
+        return predictions
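
A self-contained sketch of the new per-sample control flow (plain lists stand in for model outputs, and a hypothetical acceptable() predicate stands in for the tokenizer/punctuation check): take the first beam that passes the check, else fall back to the top-scoring beam.

    from typing import Callable, List

    def pick_per_sample(outputs_batch: List[List[List[int]]],
                        acceptable: Callable[[List[int]], bool]) -> List[List[int]]:
        # One prediction per sample: the first beam passing the check,
        # otherwise the highest-scoring beam (beams arrive sorted by score).
        predictions = []
        for beams in outputs_batch:
            chosen = next((beam for beam in beams if acceptable(beam)), beams[0])
            predictions.append(chosen)
        return predictions

    # Two samples, two beams each; accept beams longer than one token.
    batch = [[[1, 2], [3]], [[4], [5, 6]]]
    print(pick_per_sample(batch, acceptable=lambda b: len(b) > 1))
    # -> [[1, 2], [5, 6]]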