audio_to_units.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import argparse
  6. import logging
  7. import torch
  8. from seamless_communication.models.unit_extraction import UnitExtractor
  9. logging.basicConfig(level=logging.INFO)
  10. logger = logging.getLogger(__name__)
  11. def main():
  12. parser = argparse.ArgumentParser(
  13. description="Convert raw audio to units (and optionally audio) using UnitExtractor."
  14. )
  15. parser.add_argument("audio", type=str, help="Audio WAV file path.")
  16. parser.add_argument(
  17. "--kmeans_uri",
  18. type=str,
  19. help="URL path to the K-Means model.",
  20. default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
  21. )
  22. parser.add_argument(
  23. "--model_name",
  24. type=str,
  25. help="Feature extraction model name (`xlsr2_1b_v2`)",
  26. default="xlsr2_1b_v2",
  27. )
  28. parser.add_argument(
  29. "--out_layer_number",
  30. type=int,
  31. help="Layer number of the feature extraction model to pull out features from.",
  32. default=35,
  33. )
  34. args = parser.parse_args()
  35. if torch.cuda.is_available():
  36. device = torch.device("cuda:0")
  37. logger.info("Running unit_extraction on the GPU.")
  38. else:
  39. device = torch.device("cpu")
  40. logger.info("Running unit_extraction on the CPU.")
  41. unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
  42. units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
  43. logger.info(f"Converted to units: {units}")
  44. if __name__ == "__main__":
  45. main()