|
@@ -24,7 +24,8 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
|
super().__init__(args)
|
|
|
self.vocoder = vocoder
|
|
|
self.upstream_idx = args.upstream_idx
|
|
|
- self.sample_rate = args.sample_rate
|
|
|
+ self.sample_rate = args.sample_rate # input sample rate
|
|
|
+ self.vocoder_sample_rate = args.vocoder_sample_rate # output sample rate
|
|
|
self.tgt_lang = args.tgt_lang
|
|
|
self.convert_to_fbank = WaveformToFbankConverter(
|
|
|
num_mel_bins=80,
|
|
@@ -72,7 +73,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
|
|
|
|
audio_dict = {
|
|
|
"waveform": torch.tensor(source, dtype=torch.float32, device=self.device).unsqueeze(1),
|
|
|
- "sample_rate": 16000, # input audio is fixed to 16kHZ
|
|
|
+ "sample_rate": self.sample_rate,
|
|
|
"format": -1,
|
|
|
}
|
|
|
|
|
@@ -96,7 +97,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
|
SpeechSegment(
|
|
|
content=wav[0][0].tolist(),
|
|
|
finished=states.source_finished,
|
|
|
- sample_rate=self.sample_rate,
|
|
|
+ sample_rate=self.vocoder_sample_rate,
|
|
|
tgt_lang=tgt_lang,
|
|
|
),
|
|
|
finished=states.source_finished,
|
|
@@ -110,6 +111,12 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
|
default=0,
|
|
|
help="index of encoder states where states.source contains input audio",
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--vocoder-sample-rate",
|
|
|
+ type=int,
|
|
|
+ default=16000,
|
|
|
+ help="sample rate out of the vocoder"
|
|
|
+ )
|
|
|
|
|
|
@classmethod
|
|
|
def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> PretsselVocoderAgent:
|