@@ -14,6 +14,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from fairseq2.models.sequence import SequenceModelOutput
+from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from m4t_scripts.train import dataloader, dist_utils
 from torch.optim import Adam
@@ -34,21 +35,33 @@ class UnitYTrainWrapper(nn.Module):
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             self.t2u: UnitYT2UModel = self.model.t2u_model
         else:
-            raise NotImplementedError("Expand UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u")
+            raise NotImplementedError(
+                "Expand UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u"
+            )

-    def forward(self, batch: dataloader.MultimodalSeqsBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, batch: dataloader.MultimodalSeqsBatch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass, computes S2T and T2U losses"""
         assert self.model.t2u_model is not None
         assert batch.speech_to_text.src_tokens is not None
         # s2t
+        speech_padding_mask = PaddingMask(
+            seq_lens=batch.speech_to_text.src_lengths,
+            batch_seq_len=int(torch.max(batch.speech_to_text.src_lengths).item()),
+        )
         speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
             seqs=batch.speech_to_text.src_tokens,
-            seq_lens=batch.speech_to_text.src_lengths,
+            padding_mask=speech_padding_mask,
         )
         assert batch.speech_to_text.prev_output_tokens is not None
+        s2t_prev_out_tokens_padding_mask = PaddingMask(
+            seq_lens=batch.speech_to_text.target_lengths,
+            batch_seq_len=int(torch.max(batch.speech_to_text.target_lengths).item()),
+        )
         text_decoder_out, text_decoder_padding_mask = self.model.decode(
             seqs=batch.speech_to_text.prev_output_tokens,
-            seq_lens=batch.speech_to_text.target_lengths,
+            padding_mask=s2t_prev_out_tokens_padding_mask,
             encoder_output=speech_encoder_out,
             encoder_padding_mask=speech_encoder_padding_mask,
         )
@@ -61,9 +74,13 @@ class UnitYTrainWrapper(nn.Module):
             text_decoder_output=text_decoder_out,
             text_decoder_padding_mask=text_decoder_padding_mask,
         )
+        t2u_prev_out_tokens_padding_mask = PaddingMask(
+            seq_lens=batch.text_to_units.target_lengths,
+            batch_seq_len=int(torch.max(batch.text_to_units.target_lengths).item()),
+        )
         unit_decoder_out, _ = self.t2u.decode(
             seqs=batch.text_to_units.prev_output_tokens,
-            seq_lens=batch.text_to_units.target_lengths,
+            padding_mask=t2u_prev_out_tokens_padding_mask,
             encoder_output=unit_encoder_out,
             encoder_padding_mask=unit_encoder_padding_mask,
         )
@@ -94,15 +111,21 @@ class CalcLoss:
         unit_logits: torch.Tensor,
     ) -> torch.Tensor:
         assert batch.speech_to_text.target_lengths is not None
-        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(text_logits.device)
-        s2t_loss = SequenceModelOutput(logits=text_logits, pad_idx=self.s2t_pad_idx).compute_loss(
+        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
+            text_logits.device
+        )
+        s2t_loss = SequenceModelOutput(
+            logits=text_logits, pad_idx=self.s2t_pad_idx
+        ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
             ignore_prefix_size=self.s2t_ignore_prefix_size,
             label_smoothing=self.label_smoothing,
         )
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
-        s2u_loss = SequenceModelOutput(logits=unit_logits, pad_idx=self.t2u_pad_idx).compute_loss(
+        s2u_loss = SequenceModelOutput(
+            logits=unit_logits, pad_idx=self.t2u_pad_idx
+        ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
             ignore_prefix_size=1,
             label_smoothing=self.label_smoothing,
@@ -140,7 +163,10 @@ class LossCollector:
         if not self.is_distributed:
             return self.n_samples, self.val_sum
         local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
-        all_vals = [torch.zeros((1, 2), device=self.device) for _ in range(dist_utils.get_world_size())]
+        all_vals = [
+            torch.zeros((1, 2), device=self.device)
+            for _ in range(dist_utils.get_world_size())
+        ]
         dist.all_gather(all_vals, local_val)
         losses = torch.concat(all_vals, dim=0)
         reduced = torch.sum(losses, dim=0).reshape(2).cpu()
@@ -245,7 +271,9 @@ class UnitYTrainer:

     def _get_avg_bsz(self) -> float:
         """Avg training batch size"""
-        return sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
+        return (
+            sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
+        )

     def _get_ups(self) -> float:
         """Updates per second"""
@@ -267,9 +295,13 @@ class UnitYTrainer:

     def _update_eval_stats(self, eval_loss: float) -> None:
         self.last_eval_loss = eval_loss
-        self.is_best_state = self.best_eval_loss is None or eval_loss < self.best_eval_loss
+        self.is_best_state = (
+            self.best_eval_loss is None or eval_loss < self.best_eval_loss
+        )
         self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
-        self.patience_left = self.params.patience if self.is_best_state else self.patience_left - 1
+        self.patience_left = (
+            self.params.patience if self.is_best_state else self.patience_left - 1
+        )
         logger.info(
             f"Eval after {self.update_idx} updates: "
             f"loss={eval_loss:.4f} "
@@ -340,7 +372,10 @@ class UnitYTrainer:

     def _get_state(self) -> Dict[str, Any]:
         model_state_dict = self.model.state_dict()
-        model_state_dict = {key.replace("module.model.", ""): value for key, value in model_state_dict.items()}
+        model_state_dict = {
+            key.replace("module.model.", ""): value
+            for key, value in model_state_dict.items()
+        }
         return model_state_dict

     def _get_chck_path(self) -> str:
@@ -368,7 +403,9 @@ class UnitYTrainer:
         if os.path.exists(best_link_path):
             os.unlink(best_link_path)
         os.symlink(save_path, best_link_path)
-        logger.info(f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}")
+        logger.info(
+            f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}"
+        )
         if dist_utils.is_dist_initialized():
             dist.barrier()
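Each hunk above applies the same migration: fairseq2's encode_speech/decode calls now receive a padding_mask built from the per-sequence lengths instead of a raw seq_lens tensor. A minimal standalone sketch of that construction, using only the PaddingMask keyword arguments shown in the hunks (the lengths below are illustrative, not taken from the trainer):

import torch
from fairseq2.nn.padding import PaddingMask

# Hypothetical per-example lengths for a batch of three sequences.
seq_lens = torch.tensor([7, 5, 3])

# Same pattern as in the hunks: the mask carries the lengths plus the
# padded batch length, i.e. the longest sequence in the batch.
padding_mask = PaddingMask(
    seq_lens=seq_lens,
    batch_seq_len=int(torch.max(seq_lens).item()),
)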