@@ -21,13 +21,15 @@ import torch
audio_input, _ = torchaudio.load(TEST_AUDIO_PATH) # Load waveform using torchaudio

s2t_model = torch.jit.load("unity_on_device_s2t.ptl") # Load exported S2T model

-text = s2t_model(audio_input, tgt_lang=TGT_LANG) # Forward call with tgt_lang specified for ASR or S2TT
-print(f"{lang}:{text}")
+with torch.no_grad():
+    text = s2t_model(audio_input, tgt_lang=TGT_LANG) # Forward call with tgt_lang specified for ASR or S2TT
+print(text) # Show text output

s2st_model = torch.jit.load("unity_on_device.ptl")

-text, units, waveform = s2st_model(audio_input, tgt_lang=TGT_LANG) # S2ST model also returns waveform
-print(f"{lang}:{text}")
-torchaudio.save(f"{OUTPUT_FOLDER}/{lang}.wav", waveform.unsqueeze(0), sample_rate=16000) # Save output waveform to local file
+with torch.no_grad():
+    text, units, waveform = s2st_model(audio_input, tgt_lang=TGT_LANG) # S2ST model also returns waveform
+print(text)
+torchaudio.save(f"{OUTPUT_FOLDER}/result.wav", waveform.unsqueeze(0), sample_rate=16000) # Save output waveform to local file
```
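For reference, a minimal end-to-end sketch of the snippet after this change. `TEST_AUDIO_PATH`, `TGT_LANG`, and `OUTPUT_FOLDER` appear in the hunk but are not defined there; the values below are placeholder assumptions.

```python
import torch
import torchaudio

# Placeholder values -- these names come from the snippet but are not defined in the hunk.
TEST_AUDIO_PATH = "input.wav"   # assumed 16 kHz mono audio file
TGT_LANG = "eng"                # assumed target language code
OUTPUT_FOLDER = "."

audio_input, _ = torchaudio.load(TEST_AUDIO_PATH)  # Load waveform using torchaudio

# Speech-to-text (ASR / S2TT)
s2t_model = torch.jit.load("unity_on_device_s2t.ptl")  # Load exported S2T model
with torch.no_grad():  # Inference only; no gradients needed
    text = s2t_model(audio_input, tgt_lang=TGT_LANG)
print(text)

# Speech-to-speech translation (S2ST); also returns discrete units and a waveform
s2st_model = torch.jit.load("unity_on_device.ptl")
with torch.no_grad():
    text, units, waveform = s2st_model(audio_input, tgt_lang=TGT_LANG)
print(text)
torchaudio.save(f"{OUTPUT_FOLDER}/result.wav", waveform.unsqueeze(0), sample_rate=16000)  # Save translated speech
```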