|
@@ -45,7 +45,6 @@ class LAMBADA(GenerationTask):
|
|
return self.tokenizer.tokenize(text.split(" ")[0])
|
|
return self.tokenizer.tokenize(text.split(" ")[0])
|
|
|
|
|
|
def predict_single_batch(self, batch):
|
|
def predict_single_batch(self, batch):
|
|
- # micro batch size = 1 here, but we still need to return a list of predictions for consistency
|
|
|
|
outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
|
|
outputs_batch: List[List[List[int]]] = self.model.generate_text(batch, self.strategy, return_all_beams=True)
|
|
predictions = []
|
|
predictions = []
|
|
for outputs in outputs_batch:
|
|
for outputs in outputs_batch:
|