
Update training recipes for ASR and M4T large; Fix GPU OOMs with M4T large training

mavlyutov, 1 year ago
commit c4e2953141

+ 3 - 0
scripts/m4t/train/configs.py

@@ -131,6 +131,9 @@ class DataLoadingConfig(Config):
     max_tgt_text_tokens_per_batch: Optional[int] = 1000
     """ Defines flexible batch construction """
 
+    max_batch_size: Optional[int] = None
+    """ In flexible batch construction sets max allowed size"""
+
     fixed_batch_size: Optional[int] = None
     """ If set, uses fixed batch size """
 

+ 72 - 20
scripts/m4t/train/dataloader.py

@@ -76,7 +76,9 @@ class UnityDataLoader:
         self.manifest_paths = list(self._iterate_manifest_paths())
         self.text_tokenizer = self._init_text_tokenizer()
         self.unit_tokenizer = self._init_unit_tokenizer()
-        self.spm_encoder = SentencePieceEncoder(model=self.text_tokenizer.model, suffix_tokens=["</s>"])
+        self.spm_encoder = SentencePieceEncoder(
+            model=self.text_tokenizer.model, suffix_tokens=["</s>"]
+        )
         self.text_prefix_tokens = self._build_text_tgt_prefixes()
         self.unit_prefix_tokens = self._build_unit_tgt_prefixes()
         if self.config.fixed_batch_size is None:
@@ -88,18 +90,25 @@ class UnityDataLoader:
 
     @classmethod
     def _set_mkl_num_threads(cls):
-        """ Setting mkl num threads to 1, so that we don't get thread explosion."""
-        mkl_rt = ctypes.CDLL('libmkl_rt.so')
+        """Setting mkl num threads to 1, so that we don't get thread explosion."""
+        mkl_rt = ctypes.CDLL("libmkl_rt.so")
         mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1)))
 
     def _calculate_tgt_text_batch_shapes(self) -> List[Tuple[int, int]]:
         max_seq_len = self.config.max_tgt_text_tokens_per_sample
         max_tokens_per_batch = self.config.max_tgt_text_tokens_per_batch
         assert max_tokens_per_batch is not None, "max_tokens_per_batch is not set"
+        max_bsz = (
+            self.config.max_batch_size
+            if self.config.max_batch_size is not None
+            else max_tokens_per_batch
+        )
         step = self.BATCH_WIDTH_STEP
         bucket_sizes = []
         for seq_len in range(step, max(step, max_seq_len) + 1, step):
             bsz = max(1, max_tokens_per_batch // seq_len)
+            if bsz > max_bsz:
+                continue
             bucket_sizes.append((bsz, seq_len))
         return bucket_sizes
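
The cap interacts with the flexible batching arithmetic as follows; a minimal standalone sketch using the limits from the updated large_M4T_v1.yaml recipe (max_tgt_text_tokens_per_batch: 300, max_batch_size: 25). BATCH_WIDTH_STEP is not shown in this diff, so the step value below is a stand-in:

    # Illustration only: mirrors _calculate_tgt_text_batch_shapes above.
    def bucket_shapes(max_seq_len, max_tokens_per_batch, max_bsz, step):
        bucket_sizes = []
        for seq_len in range(step, max(step, max_seq_len) + 1, step):
            bsz = max(1, max_tokens_per_batch // seq_len)
            if bsz > max_bsz:
                # without the cap, short-sequence buckets would hold very large
                # batches, e.g. 300 tokens // 10-token rows = 30 rows per batch
                continue
            bucket_sizes.append((bsz, seq_len))
        return bucket_sizes

    # With max_bsz=25, the 10-token bucket (30 rows) is skipped, so every
    # remaining bucket holds at most 15 rows and peak memory stays bounded.
    print(bucket_shapes(max_seq_len=150, max_tokens_per_batch=300, max_bsz=25, step=10))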
 
@@ -128,7 +137,8 @@ class UnityDataLoader:
             assert self.config.text_tokenization.langtoks is not None
             assert self.config.text_tokenization.spm_path is not None
             return SPMTokenizer(
-                pathname=self.config.text_tokenization.spm_path, langs=self.config.text_tokenization.langtoks
+                pathname=self.config.text_tokenization.spm_path,
+                langs=self.config.text_tokenization.langtoks,
             )
 
     def _init_unit_tokenizer(self) -> UnitTokenizer:
@@ -154,7 +164,9 @@ class UnityDataLoader:
     def _infer_manifest_full_path(self, manifest_name: str) -> str:
         full_path = manifest_name.strip()
         if self.config.manifest_path_prefix is not None:
-            full_path = os.path.join(self.config.manifest_path_prefix.strip(), full_path)
+            full_path = os.path.join(
+                self.config.manifest_path_prefix.strip(), full_path
+            )
         if not full_path.endswith(self.MANIFEST_EXT) and not os.path.exists(full_path):
             full_path += self.MANIFEST_EXT
         if not os.path.exists(full_path):
@@ -188,7 +200,9 @@ class UnityDataLoader:
             self.TARGET_LANG_COLUMN,
         ]:
             if column not in column_names:
-                raise ValueError(f"Column `{column}` is not present in `{manifest_path}` ")
+                raise ValueError(
+                    f"Column `{column}` is not present in `{manifest_path}` "
+                )
         return column_names
 
     def _builder_from_manifest(self, manifest_path: str) -> DataPipelineBuilder:
@@ -230,7 +244,9 @@ class UnityDataLoader:
         # Split each text line into its fields.
         fields = self._read_column_names(manifest_path)
         logger.debug(f"Column names: {fields}")
-        txt_splitter = StrSplitter(sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True)
+        txt_splitter = StrSplitter(
+            sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True
+        )
         pipeline.map(
             txt_splitter,
             selector=self.ROOT_COLUMN,
@@ -244,7 +260,10 @@ class UnityDataLoader:
         Picks samples from per-manifest pipelines in a round-robin order"""
         # TODO: add the ability to upsample/downsample manifests
         logger.info(f"Aggregating data from {len(self.manifest_paths)} manifests")
-        builders = [self._builder_from_manifest(manifest_path=path) for path in self.manifest_paths]
+        builders = [
+            self._builder_from_manifest(manifest_path=path)
+            for path in self.manifest_paths
+        ]
         pipelines = [builder.and_return() for builder in builders]
         return DataPipeline.round_robin(pipelines=pipelines)
 
@@ -276,7 +295,7 @@ class UnityDataLoader:
             channel_last=True,  # audio channel is the last dimension in the waveform
             standardize=self.config.audio.fbanks_standardize_audio,
             keep_waveform=False,
-            device=self.target_device,
+            device=self.CPU_DEVICE,  # avoid uncontrolled memory consumption on GPUs
             dtype=self.float_dtype,
         )
         builder.map(
@@ -286,7 +305,9 @@ class UnityDataLoader:
         )
         return builder
 
-    def _attach_target_tokens(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
+    def _attach_target_tokens(
+        self, builder: DataPipelineBuilder
+    ) -> DataPipelineBuilder:
         # Convert `raw_tgt_text` to (full) target tokenized sequences:
         #                   <eos> <lang_tok> <tokens .. > <eos>
         # Lang tokens change between rows, so can't use static encoder
@@ -305,7 +326,10 @@ class UnityDataLoader:
         # 3) Not a computational blocker
         convert_to_units = lambda units_str: (  # noqa: E731
             torch.LongTensor(
-                [int(unit_id) + 4 for unit_id in units_str.rstrip().bytes().decode("utf-8").split()]
+                [
+                    int(unit_id) + 4
+                    for unit_id in units_str.rstrip().bytes().decode("utf-8").split()
+                ]
                 + [self.unit_tokenizer.vocab_info.eos_idx]
             )
         )
@@ -340,23 +364,43 @@ class UnityDataLoader:
 
     def _is_long_sample(self, sample: Any) -> bool:
         # input audio length
-        if self._get_input_audio_seconds(sample) > self.config.max_seconds_per_input_audio:
+        if (
+            self._get_input_audio_seconds(sample)
+            > self.config.max_seconds_per_input_audio
+        ):
             return True
 
         # target text tokens
-        num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[-1]
+        num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[
+            -1
+        ]
         if num_tgt_text_tokens > self.config.max_tgt_text_tokens_per_sample:
             return True
 
         # target units
-        num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[-1]  # target units
+        num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[
+            -1
+        ]  # target units
         if num_tgt_units > self.config.max_units_per_sample:
             return True
         return False
 
+    def _nans_in_fbanks(self, sample: Any) -> bool:
+        """Tells if NaNs present in fbank"""
+        fbank = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
+        has_nans: bool = torch.any(torch.isnan(fbank)).item()  # type: ignore
+        if has_nans:
+            logger.warning("Sample fbank contains NaNs. Skipping")
+        return has_nans
+
     def _filter_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
-        # Drop long samples
-        builder.filter(lambda sample: not self._is_long_sample(sample))
+        # Drop:
+        #  - "long" samples
+        #  - samples with fbanks that contain NaNs
+        builder.filter(
+            lambda sample: not self._is_long_sample(sample)
+            and not self._nans_in_fbanks(sample)
+        )
         return builder
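
A toy version of the new check, for reference; a single NaN anywhere in a sample's fbank is enough to drop it before it can turn the training loss into NaN:

    # Illustration only: the same torch.isnan check on a synthetic fbank tensor.
    import torch

    fbank = torch.randn(4, 80)       # 4 frames x 80 mel bins
    fbank[2, 5] = float("nan")       # corrupt one value
    has_nans = bool(torch.any(torch.isnan(fbank)).item())
    assert has_nans                  # such a sample is now filtered out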
 
     def _batch_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
@@ -415,8 +459,12 @@ class UnityDataLoader:
         prev_output_tokens = prev_output_tokens[:, :-1]
 
         target_tokens = tokens[:, 1:]
-        assert torch.equal(torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths)
-        assert torch.equal(torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths)
+        assert torch.equal(
+            torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths
+        )
+        assert torch.equal(
+            torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths
+        )
         return prev_output_tokens, target_tokens, target_lengths
 
     def _get_text_to_units_batch(self, raw_batch: Any) -> SeqsBatch:
@@ -448,7 +496,9 @@ class UnityDataLoader:
             prefix_tokens=prefix_tokens.to(self.target_device),
         )
 
-    def _get_speech_src_tokens_and_lengths(self, raw_batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _get_speech_src_tokens_and_lengths(
+        self, raw_batch: Any
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         fbanks = raw_batch[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
         return fbanks["seqs"].to(self.float_dtype), fbanks["seq_lens"]
 
@@ -471,7 +521,9 @@ class UnityDataLoader:
             pad_idx=pad_idx,
             eos_idx=eos_idx,
         )
-        src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(raw_batch=raw_batch)
+        src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(
+            raw_batch=raw_batch
+        )
 
         return SeqsBatch(
             src_tokens=src_tokens.to(self.target_device),
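
Keeping fbank extraction on the CPU side of the pipeline and deferring the GPU transfer to the .to(self.target_device) calls at batch-assembly time bounds device memory by the batch size rather than by the number of samples the loader has in flight. A generic PyTorch sketch of the same pattern, independent of the fairseq2 pipeline used here:

    # Illustration only: features are produced on the CPU in the dataset, and
    # only the assembled batch is moved to the GPU inside the training loop.
    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyFbankDataset(Dataset):
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            # stand-in for reading audio and extracting fbanks on the CPU
            return torch.randn(100, 80)

    loader = DataLoader(ToyFbankDataset(), batch_size=8, num_workers=0, pin_memory=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for fbanks in loader:
        fbanks = fbanks.to(device, non_blocking=True)  # one transfer per batch
        # ... forward/backward ...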

+ 3 - 3
scripts/m4t/train/recipes/asr_small_wh_transc.yaml

@@ -9,7 +9,7 @@ eval_data:
   manifest_list_path: null
   manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
   max_seconds_per_input_audio: 15
-  fixed_batch_size: 40
+  fixed_batch_size: 30
   max_tgt_text_tokens_per_batch: 1000
   max_tgt_text_tokens_per_sample: 300
   max_units_per_sample: 1500
@@ -63,7 +63,7 @@ train_data:
   manifest_list_path: null
   manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
   max_seconds_per_input_audio: 15
-  fixed_batch_size: 40
+  fixed_batch_size: 30
   max_tgt_text_tokens_per_batch: 600
   max_tgt_text_tokens_per_sample: 300
   max_units_per_sample: 1500
@@ -87,7 +87,7 @@ train_data:
   unit_tokenizer_name: seamlessM4T_large
 training:
   eval_steps: 1000 
-  float_dtype: fp32
+  float_dtype: bf16
   label_smoothing: 0.2
   learning_rate: 0.0001
   log_steps:  50 
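
This recipe (and the large_M4T_v1 recipe below) also moves float_dtype to bf16, which keeps fp32's dynamic range (unlike fp16, so no loss scaling is needed) while reducing activation memory relative to fp32. How the flag is consumed is not part of this diff; one common way to apply it, shown as a hedged sketch, is autocast around the forward and loss computation:

    # Illustration only: bf16 autocast. The actual wiring of float_dtype in
    # trainer.py is not shown in this commit.
    import torch

    model = torch.nn.Linear(80, 256).cuda()
    batch = torch.randn(4, 80, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(batch)           # matmuls run in bf16
        loss = out.float().mean()    # reduce in fp32 for stability
    loss.backward()                  # no GradScaler needed, unlike fp16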

+ 23 - 9
scripts/m4t/train/recipes/large_M4T_v1.yaml

@@ -9,7 +9,7 @@ eval_data:
   manifest_list_path: null
   manifest_path_prefix: /fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary/
   max_seconds_per_input_audio: 150
-  fixed_batch_size: 40
+  fixed_batch_size: 10
   max_tgt_text_tokens_per_batch: null
   max_tgt_text_tokens_per_sample: 3000
   max_units_per_sample: 1500
@@ -28,7 +28,20 @@ eval_data:
   unit_tokenizer_name: seamlessM4T_large
 model:
   custom_params:
+    model_embed_dim: 1024
     nllb_vocabulary_size: 256103
+    w2v2_encoder_layers: 24
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_pos_encoder_type: "relative"
+    w2v2_pos_encoder_depth: 0
+    w2v2_pos_conv_kernel_size: 0
+    w2v2_num_pos_conv_groups: 0
+    nllb_encoder_layers: 24
+    nllb_decoder_layers: 24
+    t2u_encoder_layers: 6
+    t2u_decoder_layers: 6
+    unit_vocabulary_size: 10082
   from_model: null
   from_model_config: null
   pretrained_s2t_decoder_path: /fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt
@@ -45,10 +58,11 @@ train_data:
   manifest_list_path: /data/home/mavlyutov/train_configs/m4t_v1_train_manifests.txt
   manifest_path_prefix: /fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary 
   max_seconds_per_input_audio: 15
-  fixed_batch_size: null 
-  max_tgt_text_tokens_per_batch: 600
-  max_tgt_text_tokens_per_sample: 300
-  max_units_per_sample: 1500
+  fixed_batch_size: null
+  max_batch_size: 25
+  max_tgt_text_tokens_per_batch: 300
+  max_tgt_text_tokens_per_sample: 150
+  max_units_per_sample: 1200
   num_threads: 10 
   prefech_batches: 10
   prepend_tgt_lang_tag: true
@@ -63,11 +77,11 @@ train_data:
     num_units: null
   unit_tokenizer_name: seamlessM4T_large
 training:
-  eval_steps: 5000 
-  float_dtype: fp16
+  eval_steps: 1000
+  float_dtype: bf16
   label_smoothing: 0.2
-  learning_rate: 0.0001
-  log_steps: 200 
+  learning_rate: 0.00005
+  log_steps: 200
   max_epochs: 100
   patience: 10
   start_learning_rate: 1.0e-07

+ 1 - 1
scripts/m4t/train/run_with_slurm.py

@@ -93,7 +93,7 @@ def prepare_sbatch_config(
 #SBATCH --ntasks-per-node=1
 
 ## amount of mem
-#SBATCH --mem 50G
+#SBATCH --mem 500G
 
 ## amount of time in minutes
 #SBATCH --time 2400

+ 6 - 1
scripts/m4t/train/trainer.py

@@ -326,7 +326,7 @@ class UnitYTrainer:
                     loss_val = float("Inf")
                 else:
                     loss_val = loss.item()
-                del batch  # force memory release
+                self._release_memory(batch)
                 loss_hist.update(1, loss_val)
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
@@ -369,6 +369,11 @@ class UnitYTrainer:
         self.train_loss_hist.update(1, loss.item())
         self.batch_sizes.append(batch.speech_to_text.src_tokens.shape[0])
         self._train_step_log()
+        self._release_memory(batch)
+
+    def _release_memory(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+        """ Explicitly release large memory consumers """
+        del batch
 
     def _get_state(self) -> Dict[str, Any]:
         model_state_dict = self.model.state_dict()
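
One note on the new helper: `del batch` inside _release_memory only drops the helper's local reference, so the caller's own `batch` binding still pins the tensors until it is reassigned or goes out of scope, whereas the earlier inline `del batch` released the caller's reference directly. A hedged sketch of a slightly more aggressive variant (not what this commit implements) that also trims the CUDA allocator cache occasionally:

    # Illustration only. empty_cache() synchronizes with the GPU and is slow,
    # so it is gated to every N steps rather than called on every batch.
    import torch

    def maybe_trim_cuda_cache(step: int, every: int = 1000) -> None:
        if torch.cuda.is_available() and step % every == 0:
            torch.cuda.empty_cache()

    # at the call site:
    #     del batch                  # drops the caller's reference immediately
    #     maybe_trim_cuda_cache(step)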