View Source

Merge pull request #23 from facebookresearch/adjust_setup_py

Adjust setup.py. Map scripts to binary names. Adjust Readme docs.
Ruslan Mavlyutov 2 years ago
parent
commit
21241a6c47

+ 5 - 0
.gitignore

@@ -139,3 +139,8 @@ wandb/
 nohup.out
 nohup.out
 multirun
 multirun
 outputs
 outputs
+
+
+# symlinks
+seamless_communication
+m4t_scripts

+ 9 - 3
README.md

@@ -22,22 +22,28 @@ Links:
 
 
 # Quick Start
 # Quick Start
 ## Installation
 ## Installation
+
 ```
 ```
-pip install fairseq2==0.1
 pip install .
 pip install .
 ```
 ```
 
 
+A temporary extra requirement for fairseq2 is [libsndfile](https://github.com/libsndfile/libsndfile). Within a [Conda](https://docs.conda.io/en/latest/) environment it can be installed via:
+```
+conda install -y -c conda-forge libsndfile
+```
+At this point fairseq2 has confirmed support only for Linux and macOS. Pre-built packages are only available for Linux (macOS support is planned).
+
 ## Running inference
 ## Running inference
 
 
 Here’s an example of using the CLI from the root directory to run inference.
 Here’s an example of using the CLI from the root directory to run inference.
 
 
 S2ST task:
 S2ST task:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <path_to_input_audio> s2st <tgt_lang> --output_path <path_to_save_audio>
+m4t_predict <path_to_input_audio> s2st <tgt_lang> --output_path <path_to_save_audio>
 ```
 ```
 T2TT task:
 T2TT task:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <input_text> t2tt <tgt_lang> --src_lang <src_lang>
+m4t_predict <input_text> t2tt <tgt_lang> --src_lang <src_lang>
 ```
 ```
 
 
 Please refer to the [evaluation README](scripts/m4t/predict) for detailed instructions on how to run inference.
 Please refer to the [evaluation README](scripts/m4t/predict) for detailed instructions on how to run inference.

+ 4 - 0
dev_requirements.txt

@@ -0,0 +1,4 @@
+pytest
+black
+flake8
+isort

+ 1 - 0
requirements.txt

@@ -3,3 +3,4 @@ datasets
 torchaudio
 torchaudio
 soundfile
 soundfile
 librosa
 librosa
+fairseq2==0.1.0

+ 4 - 3
scripts/m4t/finetune/README.md

@@ -29,12 +29,12 @@ Below is an example bash script that prepares a training and evaluation dataset
 export DATASET_DIR=~/m4t_dataset
 export DATASET_DIR=~/m4t_dataset
 mkdir -p $DATASET_DIR
 mkdir -p $DATASET_DIR
 
 
-python scripts/m4t/finetune/dataset.py \
+m4t_prepare_dataset \
   --source_lang eng \
   --source_lang eng \
   --target_lang kor \
   --target_lang kor \
   --split train \
   --split train \
   --save_dir $DATASET_DIR
   --save_dir $DATASET_DIR
- python scripts/m4t/finetune/dataset.py \
+m4t_prepare_dataset \
   --source_lang eng \
   --source_lang eng \
   --target_lang kor \
   --target_lang kor \
   --split validation \
   --split validation \
@@ -97,7 +97,8 @@ torchrun \
    --rdzv-endpoint=localhost:0 \
    --rdzv-endpoint=localhost:0 \
    --nnodes=1 \
    --nnodes=1 \
    --nproc-per-node=8  \
    --nproc-per-node=8  \
-  scripts/m4t/finetune/finetune.py \
+   --no-python \
+  m4t_finetune \
    --mode SPEECH_TO_TEXT \
    --mode SPEECH_TO_TEXT \
    --train_dataset $DATASET_DIR/train_manifest.json  \
    --train_dataset $DATASET_DIR/train_manifest.json  \
    --eval_dataset $DATASET_DIR/validation_manifest.json \
    --eval_dataset $DATASET_DIR/validation_manifest.json \

+ 3 - 4
scripts/m4t/finetune/dataset.py

@@ -10,7 +10,6 @@ import dataclasses
 import json
 import json
 import logging
 import logging
 import os
 import os
-from argparse import Namespace
 from pathlib import Path
 from pathlib import Path
 
 
 from seamless_communication.datasets.huggingface import (
 from seamless_communication.datasets.huggingface import (
@@ -157,7 +156,8 @@ def init_parser() -> argparse.ArgumentParser:
     return parser
     return parser
 
 
 
 
-def main(args: Namespace) -> None:
+def main() -> None:
+    args = init_parser().parse_args()
     manifest_path = download_fleurs_dataset(
     manifest_path = download_fleurs_dataset(
         source_lang=args.source_lang,
         source_lang=args.source_lang,
         target_lang=args.target_lang,
         target_lang=args.target_lang,
@@ -168,5 +168,4 @@ def main(args: Namespace) -> None:
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    args = init_parser().parse_args()
-    main(args)
+    main()

+ 5 - 8
scripts/m4t/finetune/finetune.py

@@ -7,14 +7,11 @@
 import argparse
 import argparse
 import logging
 import logging
 import os
 import os
-from argparse import Namespace
 from pathlib import Path
 from pathlib import Path
 
 
-import dataloader
-import dist_utils
 import torch
 import torch
-import trainer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from m4t_scripts.finetune import dataloader, dist_utils, trainer
 
 
 from seamless_communication.models.unity import (
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitTokenizer,
@@ -115,7 +112,7 @@ def init_parser() -> argparse.ArgumentParser:
         "--mode",
         "--mode",
         type=trainer.FinetuneMode,
         type=trainer.FinetuneMode,
         choices=list(trainer.FinetuneMode),
         choices=list(trainer.FinetuneMode),
-        default=trainer.FinetuneMode.TEXT_TO_SPEECH,
+        default=trainer.FinetuneMode.SPEECH_TO_TEXT,
         help=(
         help=(
             "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
             "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
             "* `TEXT_TO_SPEECH` -- finetune only T2U; "
             "* `TEXT_TO_SPEECH` -- finetune only T2U; "
@@ -125,7 +122,8 @@ def init_parser() -> argparse.ArgumentParser:
     return parser
     return parser
 
 
 
 
-def run_finetune(args: Namespace) -> None:
+def main() -> None:
+    args = init_parser().parse_args()
     dist_utils.init_distributed([logger, trainer.logger])
     dist_utils.init_distributed([logger, trainer.logger])
     device = torch.device("cuda")
     device = torch.device("cuda")
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
@@ -182,5 +180,4 @@ def run_finetune(args: Namespace) -> None:
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    parser = init_parser()
-    run_finetune(parser.parse_args())
+    main()

+ 3 - 3
scripts/m4t/finetune/trainer.py

@@ -12,17 +12,17 @@ from enum import Enum
 from pathlib import Path
 from pathlib import Path
 from typing import Optional, Tuple
 from typing import Optional, Tuple
 
 
-import dataloader
-import dist_utils
 import torch
 import torch
 import torch.distributed as dist
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn as nn
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.sequence import SequenceModelOutput
-from fairseq2.models.unity import UnitYModel
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
 from fairseq2.typing import Device
+from m4t_scripts.finetune import dataloader, dist_utils
 from torch.optim import Adam
 from torch.optim import Adam
 
 
+from seamless_communication.models.unity import UnitYModel
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 

+ 5 - 5
scripts/m4t/predict/README.md

@@ -16,27 +16,27 @@ The model can be specified with `--model_name` `seamlessM4T_large` or `seamlessM
 
 
 **S2ST**:
 **S2ST**:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <path_to_input_audio> s2st <tgt_lang> --output_path <path_to_save_audio> --model_name seamlessM4T_large
+m4t_predict <path_to_input_audio> s2st <tgt_lang> --output_path <path_to_save_audio> --model_name seamlessM4T_large
 ```
 ```
 
 
 **S2TT**:
 **S2TT**:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <path_to_input_audio> s2tt <tgt_lang>
+m4t_predict <path_to_input_audio> s2tt <tgt_lang>
 ```
 ```
 
 
 **T2TT**:
 **T2TT**:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <input_text> t2tt <tgt_lang> --src_lang <src_lang>
+m4t_predict <input_text> t2tt <tgt_lang> --src_lang <src_lang>
 ```
 ```
 
 
 **T2ST**:
 **T2ST**:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <input_text> t2st <tgt_lang> --src_lang <src_lang> --output_path <path_to_save_audio>
+m4t_predict <input_text> t2st <tgt_lang> --src_lang <src_lang> --output_path <path_to_save_audio>
 ```
 ```
 
 
 **ASR**:
 **ASR**:
 ```bash
 ```bash
-python scripts/m4t/predict/predict.py <path_to_input_audio> asr <tgt_lang>
+m4t_predict <path_to_input_audio> asr <tgt_lang>
 ```
 ```
 
 
 Note that it takes 16kHz audio now. Here's how you could resample your audio:
 Note that it takes 16kHz audio now. Here's how you could resample your audio:

+ 0 - 0
scripts/m4t/predict/__init__.py


+ 4 - 1
scripts/m4t/predict/predict.py

@@ -9,8 +9,11 @@ import torch
 import torchaudio
 import torchaudio
 from seamless_communication.models.inference import Translator
 from seamless_communication.models.inference import Translator
 
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
 
 
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 

+ 62 - 4
setup.py

@@ -4,12 +4,70 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+from pathlib import Path
+import os
+from typing import Iterable
+
+import pkg_resources
 from setuptools import find_packages, setup
 from setuptools import find_packages, setup
+from setuptools.command.develop import develop
+
+
+def _load_requirements(fname: str) -> Iterable[str]:
+    with open(Path(__file__).parent / fname) as fp_in:
+        for req in pkg_resources.parse_requirements(fp_in):
+            yield str(req)
+
+
+def _add_symlinks():
+    root = Path(__file__).parent
+    sc_root = root / "src/seamless_communication"
+    sc_link = root / "seamless_communication"
+    m4t_scripts_root = root / "scripts/m4t"
+    m4t_scripts_link = root / "m4t_scripts"
+    if not sc_link.exists():
+        os.symlink(sc_root, sc_link, target_is_directory=True)
+    if not m4t_scripts_link.exists():
+        os.symlink(m4t_scripts_root, m4t_scripts_link, target_is_directory=True)
+
+
+class cmd_for_editable_mode(develop):
+    def run(self):
+        # add symlinks for modules if install in editable mode
+        _add_symlinks()
+        super().run()
+
+
+default_requirements = list(_load_requirements("requirements.txt"))
+dev_requirements = list(_load_requirements("dev_requirements.txt"))
 
 
 setup(
 setup(
     name="seamless_communication",
     name="seamless_communication",
-    version="0.1",
-    packages=find_packages(where="src"),
-    package_dir={"": "src"},
-    package_data={"": ["assets/cards/*.yaml"]},
+    version="1.0.0",
+    packages=find_packages(where="src")
+    + ["m4t_scripts.finetune", "m4t_scripts.predict"],
+    package_dir={
+        "m4t_scripts": "scripts/m4t",
+        "seamless_communication": "src/seamless_communication",
+    },
+    package_data={"": ["seamless_communication/assets/cards/*.yaml"]},
+    description="SeamlessM4T -- Massively Multilingual & Multimodal Machine Translation Model",
+    long_description=open("README.md", encoding="utf-8").read(),
+    long_description_content_type="text/markdown",
+    readme="README.md",
+    python_requires=">=3.8",
+    author="Fundamental AI Research (FAIR) at Meta",
+    url="https://github.com/facebookresearch/seamless_communication",
+    license="Creative Commons",
+    install_requires=default_requirements,
+    extras_require={"dev": default_requirements + dev_requirements},
+    entry_points={
+        "console_scripts": [
+            "m4t_predict=m4t_scripts.predict.predict:main",
+            "m4t_finetune=m4t_scripts.finetune.finetune:main",
+            "m4t_prepare_dataset=m4t_scripts.finetune.dataset:main",
+        ],
+    },
+    cmdclass={"develop": cmd_for_editable_mode},
+    include_package_data=True,
 )
 )