|
@@ -12,9 +12,13 @@ import logging
|
|
import os
|
|
import os
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
+import torch
|
|
|
|
+
|
|
from seamless_communication.datasets.huggingface import (
|
|
from seamless_communication.datasets.huggingface import (
|
|
Speech2SpeechFleursDatasetBuilder,
|
|
Speech2SpeechFleursDatasetBuilder,
|
|
|
|
+ SpeechTokenizer,
|
|
)
|
|
)
|
|
|
|
+from seamless_communication.models.unit_extraction import UnitExtractor
|
|
|
|
|
|
logging.basicConfig(
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
level=logging.INFO,
|
|
@@ -91,6 +95,28 @@ def _check_lang_code_mapping(lang: str) -> None:
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+class UnitSpeechTokenizer(SpeechTokenizer):
|
|
|
|
+ MODEL_NAME = "xlsr2_1b_v2"
|
|
|
|
+ KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
|
|
|
|
+ OUTPUT_LAYER_IDX = 34
|
|
|
|
+
|
|
|
|
+ def __init__(self, device: torch.device):
|
|
|
|
+ super().__init__()
|
|
|
|
+ self.device = device
|
|
|
|
+ self.unit_extractor = UnitExtractor(
|
|
|
|
+ model_name_or_card=self.MODEL_NAME,
|
|
|
|
+ kmeans_uri=self.KMEANS_MODEL_URI,
|
|
|
|
+ device=self.device,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
|
|
|
+ return self.unit_extractor.predict(
|
|
|
|
+ wav.to(self.device),
|
|
|
|
+ out_layer_idx=self.OUTPUT_LAYER_IDX,
|
|
|
|
+ sample_rate=sample_rate,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+
|
|
def download_fleurs_dataset(
|
|
def download_fleurs_dataset(
|
|
source_lang: str,
|
|
source_lang: str,
|
|
target_lang: str,
|
|
target_lang: str,
|
|
@@ -99,7 +125,10 @@ def download_fleurs_dataset(
|
|
) -> str:
|
|
) -> str:
|
|
_check_lang_code_mapping(source_lang)
|
|
_check_lang_code_mapping(source_lang)
|
|
_check_lang_code_mapping(target_lang)
|
|
_check_lang_code_mapping(target_lang)
|
|
- tokenizer = None
|
|
|
|
|
|
+ device = (
|
|
|
|
+ torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
|
|
|
|
+ )
|
|
|
|
+ tokenizer = UnitSpeechTokenizer(device=device)
|
|
dataset_iterator = Speech2SpeechFleursDatasetBuilder(
|
|
dataset_iterator = Speech2SpeechFleursDatasetBuilder(
|
|
source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
|
|
source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
|
|
target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
|
|
target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
|
|
@@ -111,7 +140,7 @@ def download_fleurs_dataset(
|
|
)
|
|
)
|
|
manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
|
|
manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
|
|
with open(manifest_path, "w") as fp_out:
|
|
with open(manifest_path, "w") as fp_out:
|
|
- for idx, sample in enumerate(dataset_iterator, start=1):
|
|
|
|
|
|
+ for idx, sample in enumerate(dataset_iterator.__iter__(), start=1):
|
|
# correction as FleursDatasetBuilder return fleurs lang codes
|
|
# correction as FleursDatasetBuilder return fleurs lang codes
|
|
sample.source.lang = source_lang
|
|
sample.source.lang = source_lang
|
|
sample.target.lang = target_lang
|
|
sample.target.lang = target_lang
|