|
|
@@ -26,6 +26,7 @@ from seamless_communication.models.unity import (
|
|
|
UnitTokenizer,
|
|
|
UnitYGenerator,
|
|
|
UnitYModel,
|
|
|
+ UnitYT2UModel,
|
|
|
load_unity_model,
|
|
|
load_unity_text_tokenizer,
|
|
|
load_unity_unit_tokenizer,
|
|
|
@@ -242,6 +243,13 @@ class Translator(nn.Module):
|
|
|
if output_modality == Modality.TEXT:
|
|
|
return text_out.sentences[0], None, None
|
|
|
else:
|
|
|
- units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
|
|
|
+ if isinstance(self.model.t2u_model, UnitYT2UModel):
|
|
|
+ # Remove the lang token for AR UnitY.
|
|
|
+ units = unit_out.units[:, 1:]
|
|
|
+ else:
|
|
|
+ units = unit_out.units
|
|
|
+
|
|
|
+ # TODO: batch_size set to 1 for now, implement batching.
|
|
|
+ units = units[0].cpu().numpy().tolist()
|
|
|
wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
|
|
|
return text_out.sentences[0], wav_out, sample_rate
|