test_translator.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import Final
  7. import torch
  8. from fairseq2.typing import Device
  9. from seamless_communication.inference import Translator
  10. from tests.common import device
  11. # fmt: off
  12. ENG_SENTENCE: Final = "On Monday, scientists from the Stanford University School of Medicine announced the invention of a new diagnostic tool that can sort cells by type: a tiny printable chip that can be manufactured using standard inkjet printers for possibly about one U.S. cent each."
  13. DEU_SENTENCE: Final = "Am Montag kündigten Wissenschaftler der Stanford University School of Medicine die Erfindung eines neuen Diagnosewerkzeugs an, das Zellen nach Typ sortieren kann: ein winziger druckbarer Chip, der mit Standard-Tintenstrahldruckern für etwa einen US-Cent hergestellt werden kann."
  14. DEU_SENTENCE_V2: Final = "Am Montag kündigten Wissenschaftler der Stanford University School of Medicine die Erfindung eines neuen diagnostischen Werkzeugs an, das Zellen nach Typ sortieren kann: ein winziger druckbarer Chip, der mit Standard-Tintenstrahldrucker für möglicherweise etwa einen US-Cent pro Stück hergestellt werden kann."
  15. # fmt: on
  16. def test_seamless_m4t_large_t2tt() -> None:
  17. model_name = "seamlessM4T_large"
  18. src_lang = "eng"
  19. tgt_lang = "deu"
  20. if device == Device("cpu"):
  21. dtype = torch.float32
  22. else:
  23. dtype = torch.float16
  24. translator = Translator(model_name, "vocoder_36langs", device, dtype=dtype)
  25. text_output, _ = translator.predict(
  26. ENG_SENTENCE,
  27. "t2tt",
  28. tgt_lang,
  29. src_lang=src_lang,
  30. )
  31. assert text_output[0] == DEU_SENTENCE, f"'{text_output[0]}' is not '{DEU_SENTENCE}'"
  32. def test_seamless_m4t_v2_large_t2tt() -> None:
  33. model_name = "seamlessM4T_v2_large"
  34. src_lang = "eng"
  35. tgt_lang = "deu"
  36. if device == Device("cpu"):
  37. dtype = torch.float32
  38. else:
  39. dtype = torch.float16
  40. translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
  41. text_output, _ = translator.predict(
  42. ENG_SENTENCE,
  43. "t2tt",
  44. tgt_lang,
  45. src_lang=src_lang,
  46. )
  47. assert (
  48. text_output[0] == DEU_SENTENCE_V2
  49. ), f"'{text_output[0]}' is not '{DEU_SENTENCE_V2}'"
  50. def test_seamless_m4t_v2_large_multiple_tasks() -> None:
  51. model_name = "seamlessM4T_v2_large"
  52. english_text = "Hello! I hope you're all doing well."
  53. ref_spanish_text = "Hola, espero que todos estéis haciendo bien."
  54. ref_spanish_asr_text = "Hola, espero que todos estéis haciendo bien."
  55. if device == Device("cpu"):
  56. dtype = torch.float32
  57. else:
  58. dtype = torch.float16
  59. translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
  60. # Generate english speech for the english text.
  61. _, english_speech_output = translator.predict(
  62. english_text,
  63. "t2st",
  64. "eng",
  65. src_lang="eng",
  66. )
  67. assert english_speech_output is not None
  68. # Translate english speech to spanish speech.
  69. spanish_text_output, spanish_speech_output = translator.predict(
  70. english_speech_output.audio_wavs[0][0],
  71. "s2st",
  72. "spa",
  73. )
  74. assert spanish_speech_output is not None
  75. assert (
  76. spanish_text_output[0] == ref_spanish_text
  77. ), f"'{spanish_text_output[0]}' is not '{ref_spanish_text}'"
  78. # Run ASR on the spanish speech.
  79. spanish_asr_text_output, _ = translator.predict(
  80. spanish_speech_output.audio_wavs[0][0],
  81. "asr",
  82. "spa",
  83. )
  84. assert (
  85. spanish_asr_text_output[0] == ref_spanish_asr_text
  86. ), f"{spanish_asr_text_output[0]} is not {ref_spanish_asr_text}'"