12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import argparse
- import logging
- import torch
- from seamless_communication.models.unit_extraction import UnitExtractor
# Configure the root logger at INFO level so the status messages emitted by
# this script are shown, then create the module-scoped logger used below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
    """Command-line entry point: extract units from a single audio file.

    Parses the command line (a positional audio path plus the optional
    ``--kmeans_uri``, ``--model_name`` and ``--out_layer_number`` flags),
    selects ``cuda:0`` when CUDA is available and the CPU otherwise, builds a
    ``UnitExtractor`` on that device, runs ``predict`` on the audio path and
    logs the resulting units at INFO level.
    """
    cli = argparse.ArgumentParser(
        description="Convert raw audio to units (and optionally audio) using UnitExtractor."
    )
    cli.add_argument("audio", type=str, help="Audio WAV file path.")

    # (flag, type, help text, default) for every optional argument, registered
    # in the same order as they appear in the help output.
    optional_flags = (
        (
            "--kmeans_uri",
            str,
            "URL path to the K-Means model.",
            "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
        ),
        (
            "--model_name",
            str,
            "Feature extraction model name (`xlsr2_1b_v2`)",
            "xlsr2_1b_v2",
        ),
        (
            "--out_layer_number",
            int,
            "Layer number of the feature extraction model to pull out features from.",
            35,
        ),
    )
    for flag, value_type, help_text, fallback in optional_flags:
        cli.add_argument(flag, type=value_type, help=help_text, default=fallback)

    options = cli.parse_args()

    # Prefer the first CUDA device when one is available; otherwise use the CPU.
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_gpu else "cpu")
    logger.info(
        "Running unit_extraction on the GPU."
        if use_gpu
        else "Running unit_extraction on the CPU."
    )

    extractor = UnitExtractor(options.model_name, options.kmeans_uri, device=device)

    # The value handed to predict() is one less than the CLI layer number.
    # NOTE(review): presumably the flag is 1-based and predict() expects a
    # 0-based layer index -- confirm against UnitExtractor.predict.
    layer_index = options.out_layer_number - 1
    units = extractor.predict(options.audio, layer_index)
    logger.info(f"Converted to units: {units}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|