浏览代码

remove test code

cndn 2 年之前
父节点
当前提交
04a476aff2
共有 1 个文件被更改,包括 5 次插入和 38 次删除
  1. 5 38
      src/seamless_communication/models/inference/translator.py

+ 5 - 38
src/seamless_communication/models/inference/translator.py

@@ -33,13 +33,6 @@ from seamless_communication.models.unity import (
 from seamless_communication.models.unity.generator import SequenceToUnitOutput
 from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
-import urllib.request
-from urllib.request import urlopen
-import ssl
-import json
-
-ssl._create_default_https_context = ssl._create_unverified_context
-
 
 class Task(Enum):
     S2ST = auto()
@@ -106,16 +99,10 @@ class Translator(nn.Module):
         else:
             max_len_a = 1
         text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
-        unit_opts = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(max_len_a, 50)
-        )
+        unit_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(max_len_a, 50))
         if ngram_filtering:
-            text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
-                no_repeat_ngram_size=10
-            )
-            unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
-                no_repeat_ngram_size=10
-            )
+            text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=10)
+            unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=10)
         generator = UnitYGenerator(
             model,
             text_tokenizer,
@@ -125,11 +112,7 @@ class Translator(nn.Module):
             unit_opts=unit_opts,
         )
         return generator(
-            src["seqs"],
-            src["seq_lens"],
-            input_modality.value,
-            output_modality.value,
-            ngram_filtering=ngram_filtering,
+            src["seqs"], src["seq_lens"], input_modality.value, output_modality.value, ngram_filtering=ngram_filtering
         )
 
     def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
@@ -221,7 +204,7 @@ class Translator(nn.Module):
             input_modality,
             output_modality,
             tgt_lang=tgt_lang,
-            ngram_filtering=ngram_filtering,
+            ngram_filtering=ngram_filtering
         )
 
         text_out = result[0]
@@ -232,19 +215,3 @@ class Translator(nn.Module):
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
             wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
             return text_out.sentences[0], wav_out, sr_out
-
-
-if __name__ == "__main__":
-    import torchaudio
-
-    # audio = "/data/home/dnn/LJ003-0001.wav"
-    audio = "/data/home/dnn/oss_sc/seamless_communication/spanish_repeat.wav"
-    translator = Translator(
-        "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0")
-    )
-    text_out, wav, sr = translator.predict(audio, "s2st", "deu", ngram_filtering=True)  # type: ignore
-    torchaudio.save(
-        "/data/home/dnn/deu_testing.wav",
-        wav[0].cpu(),
-        sample_rate=sr,
-    )