Adjust for latest changes in fairseq2n
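
Note: this tracks the fairseq2 API change where the trainer no longer passes raw seq_lens= tensors to encode_speech, decode, and t2u.decode, and instead builds an explicit PaddingMask and passes it as padding_mask=. A minimal sketch of the pattern applied throughout the diff, using only the constructor arguments visible below; the helper name lengths_to_padding_mask is illustrative, not a fairseq2 API:

    import torch
    from fairseq2.nn.padding import PaddingMask

    def lengths_to_padding_mask(seq_lens: torch.Tensor) -> PaddingMask:
        # Wrap per-sequence lengths into the PaddingMask object that the
        # encode/decode entry points now take in place of seq_lens.
        return PaddingMask(
            seq_lens=seq_lens,
            batch_seq_len=int(torch.max(seq_lens).item()),
        )

    # Hypothetical usage mirroring the trainer change:
    #   enc_out, enc_pad_mask = model.encode_speech(
    #       seqs=src_tokens,
    #       padding_mask=lengths_to_padding_mask(src_lengths),
    #   )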

mavlyutov 1 year ago
parent
commit
d8c45cc9e7
1 file changed with 51 additions and 14 deletions
  1. scripts/m4t/train/trainer.py  +51 -14

+ 51 - 14
scripts/m4t/train/trainer.py

@@ -14,6 +14,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from fairseq2.models.sequence import SequenceModelOutput
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from m4t_scripts.train import dataloader, dist_utils
 from torch.optim import Adam
@@ -34,21 +35,33 @@ class UnitYTrainWrapper(nn.Module):
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             self.t2u: UnitYT2UModel = self.model.t2u_model
         else:
-            raise NotImplementedError("Expand UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u")
+            raise NotImplementedError(
+                "Expand UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u"
+            )
 
-    def forward(self, batch: dataloader.MultimodalSeqsBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, batch: dataloader.MultimodalSeqsBatch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass, computes S2T and T2U losses"""
         assert self.model.t2u_model is not None
         assert batch.speech_to_text.src_tokens is not None
         # s2t
+        speech_padding_mask = PaddingMask(
+            seq_lens=batch.speech_to_text.src_lengths,
+            batch_seq_len=int(torch.max(batch.speech_to_text.src_lengths).item()),
+        )
         speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
             seqs=batch.speech_to_text.src_tokens,
-            seq_lens=batch.speech_to_text.src_lengths,
+            padding_mask=speech_padding_mask,
         )
         assert batch.speech_to_text.prev_output_tokens is not None
+        s2t_prev_out_tokens_padding_mask = PaddingMask(
+            seq_lens=batch.speech_to_text.target_lengths,
+            batch_seq_len=int(torch.max(batch.speech_to_text.target_lengths).item()),
+        )
         text_decoder_out, text_decoder_padding_mask = self.model.decode(
             seqs=batch.speech_to_text.prev_output_tokens,
-            seq_lens=batch.speech_to_text.target_lengths,
+            padding_mask=s2t_prev_out_tokens_padding_mask,
             encoder_output=speech_encoder_out,
             encoder_padding_mask=speech_encoder_padding_mask,
         )
@@ -61,9 +74,13 @@ class UnitYTrainWrapper(nn.Module):
             text_decoder_output=text_decoder_out,
             text_decoder_padding_mask=text_decoder_padding_mask,
         )
+        t2u_prev_out_tokens_padding_mask = PaddingMask(
+            seq_lens=batch.text_to_units.target_lengths,
+            batch_seq_len=int(torch.max(batch.text_to_units.target_lengths).item()),
+        )
         unit_decoder_out, _ = self.t2u.decode(
             seqs=batch.text_to_units.prev_output_tokens,
-            seq_lens=batch.text_to_units.target_lengths,
+            padding_mask=t2u_prev_out_tokens_padding_mask,
             encoder_output=unit_encoder_out,
             encoder_padding_mask=unit_encoder_padding_mask,
         )
@@ -94,15 +111,21 @@ class CalcLoss:
         unit_logits: torch.Tensor,
     ) -> torch.Tensor:
         assert batch.speech_to_text.target_lengths is not None
-        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(text_logits.device)
-        s2t_loss = SequenceModelOutput(logits=text_logits, pad_idx=self.s2t_pad_idx).compute_loss(
+        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
+            text_logits.device
+        )
+        s2t_loss = SequenceModelOutput(
+            logits=text_logits, pad_idx=self.s2t_pad_idx
+        ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=self.s2t_ignore_prefix_size,
             label_smoothing=self.label_smoothing,
         )
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
-        s2u_loss = SequenceModelOutput(logits=unit_logits, pad_idx=self.t2u_pad_idx).compute_loss(
+        s2u_loss = SequenceModelOutput(
+            logits=unit_logits, pad_idx=self.t2u_pad_idx
+        ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
             label_smoothing=self.label_smoothing,
@@ -140,7 +163,10 @@ class LossCollector:
         if not self.is_distributed:
             return self.n_samples, self.val_sum
         local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
-        all_vals = [torch.zeros((1, 2), device=self.device) for _ in range(dist_utils.get_world_size())]
+        all_vals = [
+            torch.zeros((1, 2), device=self.device)
+            for _ in range(dist_utils.get_world_size())
+        ]
         dist.all_gather(all_vals, local_val)
         losses = torch.concat(all_vals, dim=0)
         reduced = torch.sum(losses, dim=0).reshape(2).cpu()
@@ -245,7 +271,9 @@ class UnitYTrainer:
 
     def _get_avg_bsz(self) -> float:
         """Avg training batch size"""
-        return sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
+        return (
+            sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
+        )
 
     def _get_ups(self) -> float:
         """Updates per second"""
@@ -267,9 +295,13 @@ class UnitYTrainer:
 
     def _update_eval_stats(self, eval_loss: float) -> None:
         self.last_eval_loss = eval_loss
-        self.is_best_state = self.best_eval_loss is None or eval_loss < self.best_eval_loss
+        self.is_best_state = (
+            self.best_eval_loss is None or eval_loss < self.best_eval_loss
+        )
         self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
-        self.patience_left = self.params.patience if self.is_best_state else self.patience_left - 1
+        self.patience_left = (
+            self.params.patience if self.is_best_state else self.patience_left - 1
+        )
         logger.info(
             f"Eval after {self.update_idx} updates: "
             f"loss={eval_loss:.4f} "
@@ -340,7 +372,10 @@ class UnitYTrainer:
 
     def _get_state(self) -> Dict[str, Any]:
         model_state_dict = self.model.state_dict()
-        model_state_dict = {key.replace("module.model.", ""): value for key, value in model_state_dict.items()}
+        model_state_dict = {
+            key.replace("module.model.", ""): value
+            for key, value in model_state_dict.items()
+        }
         return model_state_dict
 
     def _get_chck_path(self) -> str:
@@ -368,7 +403,9 @@ class UnitYTrainer:
                 if os.path.exists(best_link_path):
                     os.unlink(best_link_path)
                 os.symlink(save_path, best_link_path)
-                logger.info(f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}")
+                logger.info(
+                    f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}"
+                )
         if dist_utils.is_dist_initialized():
             dist.barrier()