# audio_to_units.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
  5. import argparse
  6. import logging
  7. import torch
  8. import torchaudio
  9. from seamless_communication.models.unit_extraction import UnitExtractor
  10. from seamless_communication.models.inference import Translator
  11. from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
  12. from itertools import groupby
  13. logging.basicConfig(level=logging.INFO)
  14. logger = logging.getLogger(__name__)
  15. def main():
  16. parser = argparse.ArgumentParser(
  17. description="Convert raw audio to units (and optionally audio) using UnitExtractor."
  18. )
  19. parser.add_argument("audio", type=str, help="Audio WAV file path.")
  20. parser.add_argument(
  21. "--kmeans_uri",
  22. type=str,
  23. help="URL path to the K-Means model.",
  24. default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
  25. )
  26. parser.add_argument(
  27. "--model_name",
  28. type=str,
  29. help="Feature extraction model name (`xlsr2_1b_v2`)",
  30. default="xlsr2_1b_v2",
  31. )
  32. parser.add_argument(
  33. "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
  34. )
  35. parser.add_argument(
  36. "--out_layer_number",
  37. type=int,
  38. help="Layer number of the feature extraction model to pull out features from.",
  39. default=35,
  40. )
  41. parser.add_argument(
  42. "--output_path",
  43. type=str,
  44. help="Path to save the generated audio.",
  45. default=None,
  46. )
  47. parser.add_argument(
  48. "--src_lang", type=str, help="Source language of the audio.", default=None
  49. )
  50. args = parser.parse_args()
  51. if torch.cuda.is_available():
  52. device = torch.device("cuda:0")
  53. logger.info("Running unit_extraction on the GPU.")
  54. else:
  55. device = torch.device("cpu")
  56. logger.info("Running unit_extraction on the CPU.")
  57. unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
  58. units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
  59. if args.output_path is not None:
  60. if args.src_lang is None:
  61. raise ValueError("src_lang must be provided to resynthesize the audio.")
  62. def reduce_list(lst):
  63. return [key for key, _ in groupby(lst)]
  64. reduced_units = reduce_list(units.cpu().tolist())
  65. vocoder: Vocoder = Translator.load_model_for_inference(
  66. load_vocoder_model, args.vocoder_name, device, torch.float32
  67. )
  68. wav = vocoder(reduced_units, args.src_lang, spkr=-1, dur_prediction=True)
  69. torchaudio.save(
  70. args.output_path,
  71. wav[0].cpu(),
  72. sample_rate=16000,
  73. )
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()