Переглянути джерело

Switch to S2T finetuning in example scripts. Add support for null units.

Ruslan Mavlyutov 2 роки тому
батько
коміт
b0a4da0dab

+ 2 - 0
requirements.txt

@@ -1,3 +1,5 @@
 pre-commit
 datasets
 torchaudio
+soundfile
+librosa

+ 3 - 2
scripts/m4t/finetune/README.md

@@ -86,7 +86,7 @@ The scripts supports three modes of finetuning:
 - `TEXT_TO_SPEECH`: only text-to-unit part of the model will be engaged in the finetuning, other weights will be frozen;
 - `SPEECH_TO_TEXT`: only speech-to-text part of the model will be engaged in the finetuning.
 
-The referenced finetuning script does not support finetuning of the text encoder. Though the code expantion should be trivial.
+The referenced finetuning script does not support finetuning of the text encoder.
 
 
 Below is an example bash script that launches finetuning of M4T-large on the dataset prepared earlier, using a single node with eight GPUs:
@@ -98,6 +98,7 @@ torchrun \
    --nnodes=1 \
    --nproc-per-node=8  \
   scripts/m4t/finetune/finetune.py \
+   --mode SPEECH_TO_TEXT \
    --train_dataset $DATASET_DIR/train_manifest.json  \
    --eval_dataset $DATASET_DIR/validation_manifest.json \
    --learning_rate 1e-6 \
@@ -105,7 +106,7 @@ torchrun \
    --max_epochs 10 \
    --patience 3 \
    --model_name seamlessM4T_large \
-   --save_model_to $WORKDIR/checkpoint_lr_1e-6_full.pt
+   --save_model_to $DATASET_DIR/checkpoint.pt
 ```
 
 Excerpt from an example finetuning log:

+ 25 - 13
scripts/m4t/finetune/dataloader.py

@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
 class SeqsBatch:
     src_tokens: Optional[Tensor]
     src_lengths: Optional[Tensor]
-    target_tokens: Tensor
-    prev_output_tokens: Tensor
-    target_lengths: Tensor
+    target_tokens: Optional[Tensor]
+    prev_output_tokens: Optional[Tensor]
+    target_lengths: Optional[Tensor]
 
     def __del__(self) -> None:
         """Explicitly delete tensors
@@ -136,8 +136,10 @@ class UnitYDataLoader:
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
         return tokens
 
-    def _get_tokenized_units(self, sample: LangPairSample) -> Tensor:
+    def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
         """Expected sequence is [<eos>, <lang_tok> , ..unit tokens.., <eos>]"""
+        if sample.target.units is None:
+            return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
             self.unit_encoders_per_lang[
@@ -185,15 +187,25 @@ class UnitYDataLoader:
             [tokens.shape[0] - 1 for tokens in text_tokens_list]
         )
         # output units
-        units_list = [self._get_tokenized_units(sample) for sample in samples]
-        units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
-        prev_outputs_units = self._batch_tensors(
-            [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
-        )
-        target_units = self._batch_tensors(
-            [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
-        )
-        units_lengths = torch.LongTensor([tokens.shape[0] - 1 for tokens in units_list])
+        units_list_raw = [self._get_tokenized_units(sample) for sample in samples]
+        if None in units_list_raw:
+            prev_outputs_units = None
+            target_units = None
+            units_lengths = None
+        else:
+            units_list: List[Tensor] = [
+                value for value in units_list_raw if value is not None
+            ]
+            units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
+            prev_outputs_units = self._batch_tensors(
+                [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
+            )
+            target_units = self._batch_tensors(
+                [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
+            )
+            units_lengths = torch.LongTensor(
+                [tokens.shape[0] - 1 for tokens in units_list]
+            )
         return MultimodalSeqsBatch(
             speech_to_text=SeqsBatch(
                 src_tokens=src_tokens,

+ 9 - 4
scripts/m4t/finetune/trainer.py

@@ -86,9 +86,8 @@ class UnitYFinetuneWrapper(nn.Module):
 
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert self.model.t2u_model is not None
-
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
@@ -96,6 +95,7 @@ class UnitYFinetuneWrapper(nn.Module):
                 seqs=batch.speech_to_text.src_tokens.to(self.device),
                 seq_lens=batch.speech_to_text.src_lengths.to(self.device),
             )
+            assert batch.speech_to_text.prev_output_tokens is not None
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
                 seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
                 seq_lens=batch.speech_to_text.target_lengths.to(self.device),
@@ -103,7 +103,8 @@ class UnitYFinetuneWrapper(nn.Module):
                 encoder_padding_mask=speech_encoder_padding_mask,
             )
             text_logits = self.model.final_proj(text_decoder_out)
-
+        if batch.text_to_units.prev_output_tokens is None:
+            return (text_logits, None)
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
             (
@@ -141,8 +142,9 @@ class CalcLoss:
         self,
         batch: dataloader.MultimodalSeqsBatch,
         text_logits: torch.Tensor,
-        unit_logits: torch.Tensor,
+        unit_logits: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        assert batch.speech_to_text.target_lengths is not None
         s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
             text_logits.device
         )
@@ -153,6 +155,9 @@ class CalcLoss:
             ignore_prefix_size=1,
             label_smoothing=self.label_smoothing,
         )
+        if unit_logits is None:
+            return s2t_loss / s2t_numel
+        assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
             logits=unit_logits, pad_idx=self.t2u_pad_idx