|
@@ -117,7 +117,7 @@ class ModelForEvaluation(torch.nn.Module):
|
|
|
log_probs.append(log_probs_single)
|
|
|
return log_probs
|
|
|
|
|
|
- def generate_text(self, sample, strategy, return_all_beams=False, max_gen_length=128) -> Union[
|
|
|
+ def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
|
|
|
List[int], List[List[int]]]:
|
|
|
"""
|
|
|
@return: A list of text model generated, sorted by score in descending order
|
|
@@ -128,6 +128,7 @@ class ModelForEvaluation(torch.nn.Module):
|
|
|
|
|
|
def get_masks_and_position_ids(seq):
|
|
|
batch_size = seq.shape[0]
|
|
|
+ max_gen_length = sample['target_position_ids'].shape[-1]
|
|
|
tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode='constant', value=-1)
|
|
|
position_ids = torch.cat((sample['position_ids'], sample['target_position_ids']), dim=-1)
|
|
|
position_ids = position_ids.to(device=torch.cuda.current_device()).long()
|