|
@@ -367,12 +367,12 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
break
|
|
break
|
|
|
|
|
|
pred_indices.append(index)
|
|
pred_indices.append(index)
|
|
- if self.state_bag.step == 0:
|
|
|
|
- self.state_bag.increment_step(
|
|
|
|
|
|
+ if self.state_bag.step_nr == 0:
|
|
|
|
+ self.state_bag.increment_step_nr(
|
|
len(self.prefix_indices + states.target_indices)
|
|
len(self.prefix_indices + states.target_indices)
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
- self.state_bag.increment_step()
|
|
|
|
|
|
+ self.state_bag.increment_step_nr()
|
|
|
|
|
|
states.target_indices += pred_indices
|
|
states.target_indices += pred_indices
|
|
|
|
|
|
@@ -426,7 +426,7 @@ class UnitYMMATextDecoderAgent(MMASpeechToTextDecoderAgent):
|
|
# TODO: a temporary solution.
|
|
# TODO: a temporary solution.
|
|
ending_token_index = self.text_tokenizer.model.token_to_index(",")
|
|
ending_token_index = self.text_tokenizer.model.token_to_index(",")
|
|
token_list.append(ending_token_index)
|
|
token_list.append(ending_token_index)
|
|
- self.state_bag.increment_step()
|
|
|
|
|
|
+ self.state_bag.increment_step_nr()
|
|
|
|
|
|
_, _, decoder_features = self.run_decoder(states, [ending_token_index])
|
|
_, _, decoder_features = self.run_decoder(states, [ending_token_index])
|
|
decoder_features_out = torch.cat(
|
|
decoder_features_out = torch.cat(
|