|
@@ -118,7 +118,7 @@ class ModelForEvaluation(torch.nn.Module):
|
|
return log_probs
|
|
return log_probs
|
|
|
|
|
|
def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
|
|
def generate_text(self, sample, strategy, return_all_beams=False) -> Union[
|
|
- List[int], List[List[int]]]:
|
|
|
|
|
|
+ List[List[int]], List[List[List[int]]]]:
|
|
"""
|
|
"""
|
|
@return: A list of text model generated, sorted by score in descending order
|
|
@return: A list of text model generated, sorted by score in descending order
|
|
"""
|
|
"""
|
|
@@ -155,16 +155,16 @@ class ModelForEvaluation(torch.nn.Module):
|
|
)[0]
|
|
)[0]
|
|
|
|
|
|
if isinstance(output, torch.Tensor): # different strategies
|
|
if isinstance(output, torch.Tensor): # different strategies
|
|
- output = list(output)
|
|
|
|
|
|
+ output = output.tolist()
|
|
|
|
|
|
output_targets = []
|
|
output_targets = []
|
|
context_length = seqs.shape[1]
|
|
context_length = seqs.shape[1]
|
|
for lines in output:
|
|
for lines in output:
|
|
|
|
+ lines = lines.tolist() if isinstance(lines, torch.Tensor) else lines
|
|
output_target = []
|
|
output_target = []
|
|
if not isinstance(lines, list):
|
|
if not isinstance(lines, list):
|
|
lines = [lines]
|
|
lines = [lines]
|
|
for line in lines:
|
|
for line in lines:
|
|
- line = line.tolist()
|
|
|
|
unfinished = line.index(-1) if -1 in line else len(line)
|
|
unfinished = line.index(-1) if -1 in line else len(line)
|
|
if line[unfinished - 1] in strategy.end_tokens:
|
|
if line[unfinished - 1] in strategy.end_tokens:
|
|
unfinished -= 1
|
|
unfinished -= 1
|