Ruslan Mavlyutov 1 year ago
parent commit 8bd011b211

+ 38 - 15
src/seamless_communication/cli/m4t/finetune/dataloader.py

@@ -8,7 +8,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -100,6 +100,7 @@ class UnitYDataLoader:
         unit_tokenizer: UnitTokenizer,
         dataset_manifest_path: str,
         batching_config: BatchingConfig,
+        max_src_tokens_per_batch: int = 100000,
     ):
         self.text_tokenizer = text_tokenizer
         self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
@@ -115,6 +116,7 @@ class UnitYDataLoader:
             "dtype": self.batching_config.float_dtype,
         }
         self.dataset = self._load_manifest(dataset_manifest_path)
+        self.max_src_tokens_per_batch = max_src_tokens_per_batch
 
     def get_dataloader(self) -> DataLoader[SeqsBatch]:
         subset = split_dataset_by_node(
@@ -156,9 +158,9 @@ class UnitYDataLoader:
         """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
         target_lang = sample.target.lang
         if target_lang not in self.text_encoders_per_lang:
-            self.text_encoders_per_lang[
-                target_lang
-            ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            self.text_encoders_per_lang[target_lang] = (
+                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            )
         tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
         eos_idx = self.text_tokenizer.vocab_info.eos_idx
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
@@ -170,9 +172,9 @@ class UnitYDataLoader:
             return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
-            self.unit_encoders_per_lang[
-                target_lang
-            ] = self.unit_tokenizer.create_encoder(lang=target_lang)
+            self.unit_encoders_per_lang[target_lang] = (
+                self.unit_tokenizer.create_encoder(lang=target_lang)
+            )
         tokens = self.unit_encoders_per_lang[target_lang](
             torch.LongTensor(sample.target.units).unsqueeze(0)
         )
@@ -195,20 +197,41 @@ class UnitYDataLoader:
         length_s: float = max(wav.shape) / sample_rate
         return length_s > self.batching_config.max_audio_length_sec
 
+    def _drop_overflow_samples(
+        self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
+    ) -> List[Tuple[LangPairSample, torch.Tensor]]:
+        # sort by fbank length, longest first: the longest sample sets the padded batch width
+        samples_with_fbanks = sorted(
+            samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
+        )
+        bwd = samples_with_fbanks[0][1].shape[0]
+        max_samples_for_batch = max(1, self.max_src_tokens_per_batch // bwd)
+        if max_samples_for_batch < len(samples_with_fbanks):
+            samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
+        return samples_with_fbanks
+
     def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
         samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
         #  - filter long audio samples
-        filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
-        samples = filtered_samples if filtered_samples else [samples[0]]  # keep at least one sample
-        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+        filtered_samples = [
+            sample for sample in samples if not self._is_long_src_audio(sample)
+        ]
+        samples = (
+            filtered_samples if filtered_samples else [samples[0]]
+        )  # keep at least one sample
+        with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
         #  - filter NaNs in fbanks
-        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
-        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
-        assert len(samples) > 0
-        src_tokens_list = [
-            src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
+        filtered = [
+            (sample, fbank)
+            for sample, fbank in with_fbanks
+            if not fbank.isnan().any().item()
         ]
+        filtered = self._drop_overflow_samples(filtered)
+
+        samples = [sample for sample, _ in filtered]
+        src_tokens_list = [src_tokens for _, src_tokens in filtered]
+        assert len(samples) > 0
         src_tokens = self._batch_tensors(
             src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
         ).to(self.batching_config.float_dtype)

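The new _drop_overflow_samples helper can be read in isolation: every fbank in a batch is padded up to the longest one, so the padded cost is len(batch) * longest, and capping the sample count at max_src_tokens_per_batch // longest bounds that cost. Below is a minimal standalone sketch of the same logic; the free function drop_overflow_samples, the string sample IDs, and the toy 80-bin fbanks are illustrative, not part of the commit.

import torch
from typing import List, Tuple

def drop_overflow_samples(
    samples_with_fbanks: List[Tuple[str, torch.Tensor]],
    max_src_tokens_per_batch: int,
) -> List[Tuple[str, torch.Tensor]]:
    # Longest first: after padding, every sample occupies as many
    # frames as the longest one in the batch.
    samples_with_fbanks = sorted(samples_with_fbanks, key=lambda sb: -sb[1].shape[0])
    longest = samples_with_fbanks[0][1].shape[0]
    # max(1, ...) keeps at least one sample even if it alone exceeds the budget.
    max_samples = max(1, max_src_tokens_per_batch // longest)
    return samples_with_fbanks[:max_samples]

# Toy batch: (id, fbank) pairs with 80 mel bins and varying frame counts.
batch = [
    ("a", torch.zeros(900, 80)),
    ("b", torch.zeros(300, 80)),
    ("c", torch.zeros(500, 80)),
]
kept = drop_overflow_samples(batch, max_src_tokens_per_batch=2000)
print([sample_id for sample_id, _ in kept])  # ['a', 'c']: padded cost 2 * 900 = 1800 <= 2000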
+ 1 - 0
src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -176,6 +176,7 @@ def main() -> None:
             float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
+        max_src_tokens_per_batch=7000,
     )
     eval_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
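To get a feel for the hard-coded max_src_tokens_per_batch=7000 passed to the train dataloader: assuming the usual 10 ms fbank frame shift (about 100 frames per second of audio; the diff itself does not state the frame rate), the budget translates into a per-batch sample cap driven by the longest clip in the batch.

# Back-of-the-envelope for max_src_tokens_per_batch=7000, assuming
# ~100 fbank frames per second (10 ms frame shift) -- an assumption,
# not something stated in the diff.
FRAMES_PER_SEC = 100
BUDGET = 7000

for longest_clip_sec in (5, 10, 20):
    longest_frames = longest_clip_sec * FRAMES_PER_SEC
    # mirrors max(1, self.max_src_tokens_per_batch // bwd) in _drop_overflow_samples
    print(longest_clip_sec, "s ->", max(1, BUDGET // longest_frames), "samples per batch")
# 5 s -> 14, 10 s -> 7, 20 s -> 3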