|
@@ -244,7 +244,8 @@ class Translator(nn.Module):
|
|
|
return text_out.sentences[0], None, None
|
|
return text_out.sentences[0], None, None
|
|
|
else:
|
|
else:
|
|
|
if isinstance(self.model.t2u_model, UnitYT2UModel):
|
|
if isinstance(self.model.t2u_model, UnitYT2UModel):
|
|
|
- # Remove the lang token for AR UnitY.
|
|
|
|
|
|
|
+ # Remove the lang token for AR UnitY since the vocoder doesn't need it
|
|
|
|
|
+ # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
|
|
|
units = unit_out.units[:, 1:]
|
|
units = unit_out.units[:, 1:]
|
|
|
else:
|
|
else:
|
|
|
units = unit_out.units
|
|
units = unit_out.units
|