
Merge pull request #60 from facebookresearch/restore_unit_extraction_in_dataset_script

Re-enable unit extraction in dataset preparation script
Ruslan Mavlyutov 2 years ago
parent
commit
519f46576e
3 changed files with 48 additions and 6 deletions
  1. .gitignore (+3, -1)
  2. scripts/m4t/finetune/dataset.py (+31, -2)
  3. src/seamless_communication/datasets/huggingface.py (+14, -3)

+ 3 - 1
.gitignore

@@ -143,4 +143,6 @@ outputs
 
 # symlinks
 seamless_communication
-m4t_scripts
+# ignore src/seamless_communication  
+!*/seamless_communication
+m4t_scripts
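
Note on the .gitignore rules: a bare pattern like "seamless_communication" matches a file or directory with that name at any depth, so on its own it would also hide the real package at src/seamless_communication. The added negation "!*/seamless_communication" re-includes that directory one level below the repository root, leaving only the root-level symlink ignored.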

+ 31 - 2
scripts/m4t/finetune/dataset.py

@@ -12,9 +12,13 @@ import logging
 import os
 from pathlib import Path
 
+import torch
+
 from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
+    SpeechTokenizer,
 )
+from seamless_communication.models.unit_extraction import UnitExtractor
 
 logging.basicConfig(
     level=logging.INFO,
@@ -91,6 +95,28 @@ def _check_lang_code_mapping(lang: str) -> None:
         )
 
 
+class UnitSpeechTokenizer(SpeechTokenizer):
+    MODEL_NAME = "xlsr2_1b_v2"
+    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
+    OUTPUT_LAYER_IDX = 34
+
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.device = device
+        self.unit_extractor = UnitExtractor(
+            model_name_or_card=self.MODEL_NAME,
+            kmeans_uri=self.KMEANS_MODEL_URI,
+            device=self.device,
+        )
+
+    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        return self.unit_extractor.predict(
+            wav.to(self.device),
+            out_layer_idx=self.OUTPUT_LAYER_IDX,
+            sample_rate=sample_rate,
+        )
+
+
 def download_fleurs_dataset(
     source_lang: str,
     target_lang: str,
@@ -99,7 +125,10 @@ def download_fleurs_dataset(
 ) -> str:
     _check_lang_code_mapping(source_lang)
     _check_lang_code_mapping(target_lang)
-    tokenizer = None
+    device = (
+        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
+    )
+    tokenizer = UnitSpeechTokenizer(device=device)
     dataset_iterator = Speech2SpeechFleursDatasetBuilder(
         source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
         target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
@@ -111,7 +140,7 @@ def download_fleurs_dataset(
     )
     manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
     with open(manifest_path, "w") as fp_out:
-        for idx, sample in enumerate(dataset_iterator, start=1):
+        for idx, sample in enumerate(dataset_iterator.__iter__(), start=1):
             # correction as FleursDatasetBuilder return fleurs lang codes
             sample.source.lang = source_lang
             sample.target.lang = target_lang
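
For context, here is a minimal usage sketch of the new UnitSpeechTokenizer outside the download flow. It is illustrative only: the audio file name is hypothetical, the import assumes you run from scripts/m4t/finetune, and the first call should download the xlsr2_1b_v2 checkpoint and the k-means model referenced above.

import torch
import torchaudio

from dataset import UnitSpeechTokenizer  # the class added in this PR

device = torch.device("cuda:0" if torch.cuda.device_count() > 0 else "cpu")
tokenizer = UnitSpeechTokenizer(device=device)

# "utterance.wav" is a placeholder; torchaudio.load returns a
# (channels, num_samples) tensor and the file's sample rate.
wav, sample_rate = torchaudio.load("utterance.wav")

# Pass a 1-D waveform, matching how the dataset builder calls encode().
units = tokenizer.encode(wav.squeeze(0), sample_rate)
print(units.shape[0], "discrete units extracted")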

+ 14 - 3
src/seamless_communication/datasets/huggingface.py

@@ -7,7 +7,8 @@
 
 import logging
 import os
-from typing import Any, Dict, Iterable, Optional
+from abc import abstractmethod
+from typing import Dict, Iterable, Optional
 
 import numpy as np
 import torch
@@ -18,6 +19,12 @@ from .datatypes import LangPairSample, MultimodalSample
 logger = logging.getLogger(__name__)
 
 
+class SpeechTokenizer:
+    @abstractmethod
+    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        ...
+
+
 class Speech2SpeechFleursDatasetBuilder:
     """Assembles speech2speech dataset from google/fleurs on HuggingFace"""
 
@@ -32,7 +39,7 @@ class Speech2SpeechFleursDatasetBuilder:
         skip_target_audio: bool = True,
         audio_dtype: torch.dtype = torch.float32,
         dataset_cache_dir: Optional[str] = None,
-        speech_tokenizer: Optional[Any] = None,
+        speech_tokenizer: Optional[SpeechTokenizer] = None,
     ):
         self.source_lang = source_lang
         self.target_lang = target_lang
@@ -65,7 +72,11 @@ class Speech2SpeechFleursDatasetBuilder:
             waveform = None
         if self.speech_tokenizer is not None and not should_skip_audio:
             assert waveform is not None
-            units = self.speech_tokenizer.encode(waveform.unsqueeze(0))[0].tolist()
+            assert sampling_rate is not None
+            units_tensor = self.speech_tokenizer.encode(
+                waveform, sampling_rate
+            ).reshape(-1)
+            units = units_tensor.tolist()
         else:
             units = None
         return MultimodalSample(
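
To make the new extension point concrete, here is a sketch of a custom SpeechTokenizer (the class name, the 320-sample frame size, and the constant unit ids are stand-ins for illustration, not a real quantizer):

import torch

from seamless_communication.datasets.huggingface import SpeechTokenizer

class FixedRateTokenizer(SpeechTokenizer):
    """Stand-in quantizer: one fake unit id per 320 input samples.
    A real implementation would run a speech encoder plus k-means
    quantization, as UnitSpeechTokenizer does in dataset.py above."""

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        num_units = max(1, wav.numel() // 320)  # ~20 ms frames at 16 kHz
        return torch.zeros(num_units, dtype=torch.long)

Any such instance can then be passed to Speech2SpeechFleursDatasetBuilder via the speech_tokenizer argument typed above.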