@@ -20,10 +20,13 @@ from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
-from torch.optim import Adam
+from torch.optim import AdamW
 
 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
-from seamless_communication.models.unity import UnitYModel
+from seamless_communication.models.unity import (
+    UnitYModel,
+    UnitYT2UModel,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -80,26 +83,27 @@ class UnitYFinetuneWrapper(nn.Module):
 
     def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
         super().__init__()
-        assert model.t2u_model is not None
         self.model: UnitYModel = model
         self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
         self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
+        logger.info(f"Freeze s2t: {self.freeze_s2t}, freeze t2u: {self.freeze_t2u}")
         self.device = device
 
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        assert self.model.t2u_model is not None
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
             seqs = batch.speech_to_text.src_tokens.to(self.device)
+            assert batch.speech_to_text.src_lengths is not None
             seq_lens = batch.speech_to_text.src_lengths.to(self.device)
             speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
                 seqs=seqs, padding_mask=PaddingMask(seq_lens, seqs.size(1))
             )
             assert batch.speech_to_text.prev_output_tokens is not None
             seqs = batch.speech_to_text.prev_output_tokens.to(self.device)
+            assert batch.speech_to_text.target_lengths is not None
             seq_lens = batch.speech_to_text.target_lengths.to(self.device)
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
                 seqs=seqs,
@@ -107,19 +111,27 @@ class UnitYFinetuneWrapper(nn.Module):
                 encoder_output=speech_encoder_out,
                 encoder_padding_mask=speech_encoder_padding_mask,
             )
+            assert self.model.final_proj is not None
             text_logits = self.model.final_proj(text_decoder_out)
-        if batch.text_to_units.prev_output_tokens is None:
+        if self.freeze_t2u:
             return (text_logits, None)
+        assert self.model.t2u_model is not None
+        assert batch.text_to_units.prev_output_tokens is not None
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
+            if not isinstance(self.model.t2u_model, UnitYT2UModel):
+                raise NotImplementedError(
+                    "T2U finetuning implemented only for UnitYT2UModel"
+                )
             (
                 unit_encoder_out,
                 unit_encoder_padding_mask,
             ) = self.model.t2u_model.encode(
-                text_decoder_output=text_decoder_out,
-                text_decoder_padding_mask=text_decoder_padding_mask,
+                seqs=text_decoder_out,
+                padding_mask=text_decoder_padding_mask,
             )
             seqs = batch.text_to_units.prev_output_tokens.to(self.device)
+            assert batch.text_to_units.target_lengths is not None
             seq_lens = batch.text_to_units.target_lengths.to(self.device)
             unit_decoder_out, _ = self.model.t2u_model.decode(
                 seqs=seqs,
@@ -139,7 +151,7 @@ class CalcLoss:
         self,
         label_smoothing: float,
         s2t_vocab_info: VocabularyInfo,
-        t2u_vocab_info: VocabularyInfo,
+        t2u_vocab_info: Optional[VocabularyInfo],
     ):
         self.label_smoothing = label_smoothing
         self.s2t_vocab_info = s2t_vocab_info
@@ -152,25 +164,31 @@ class CalcLoss:
         unit_logits: Optional[torch.Tensor],
     ) -> torch.Tensor:
         assert batch.speech_to_text.target_lengths is not None
-        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
+        prefix_skip_len = 1  # language tokens to skip
+        s2t_numel = torch.sum(batch.speech_to_text.target_lengths - prefix_skip_len).to(
             text_logits.device
         )
+        assert batch.speech_to_text.target_tokens is not None
         s2t_loss = SequenceModelOutput(
             logits=text_logits, vocab_info=self.s2t_vocab_info
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
-            ignore_prefix_size=1,
+            ignore_prefix_size=prefix_skip_len,
             label_smoothing=self.label_smoothing,
         )
         if unit_logits is None:
             return s2t_loss / s2t_numel
         assert batch.text_to_units.target_lengths is not None
-        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
+        s2u_numel = torch.sum(batch.text_to_units.target_lengths - prefix_skip_len).to(
+            unit_logits.device
+        )
+        assert batch.text_to_units.target_tokens is not None
+        assert self.t2u_vocab_info is not None
         s2u_loss = SequenceModelOutput(
             logits=unit_logits, vocab_info=self.t2u_vocab_info
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
-            ignore_prefix_size=1,
+            ignore_prefix_size=prefix_skip_len,
             label_smoothing=self.label_smoothing,
         )
         return s2t_loss / s2t_numel + s2u_loss / s2u_numel
@@ -225,17 +243,17 @@ class UnitYFinetune:
         eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
     ):
         self.params = params
-
-        assert model.t2u_model is not None
         self.calc_loss = CalcLoss(
             label_smoothing=self.params.label_smoothing,
             s2t_vocab_info=model.target_vocab_info,
-            t2u_vocab_info=model.t2u_model.target_vocab_info,
+            t2u_vocab_info=model.t2u_model.target_vocab_info
+            if model.t2u_model is not None
+            else None,
         )
         self.model = self._wrap_model_for_trainining(model=model)
         self.train_data_loader = train_data_loader
         self.eval_data_loader = eval_data_loader
-        self.optimizer = Adam(
+        self.optimizer = AdamW(
             params=self.model.parameters(),
             lr=self.params.learning_rate,
             betas=(0.9, 0.98),
@@ -244,7 +262,7 @@ class UnitYFinetune:
             weight_decay=0.0,
             fused=True,
         )
-        self.grad_scaler = torch.cuda.amp.GradScaler()
+        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.lr_scheduler = MyleLR(
             optimizer=self.optimizer,
             num_warmup_steps=self.params.warmup_steps,
@@ -257,6 +275,7 @@ class UnitYFinetune:
         self.patience_left: int = self.params.patience
         self.best_eval_loss: Optional[float] = None
         self.is_best_state: bool = False
+        torch.set_float32_matmul_precision("high")
 
     def _reset_stats(self) -> None:
         self.train_loss_hist.reset()
@@ -272,10 +291,11 @@ class UnitYFinetune:
         )
         if not dist_utils.is_dist_initialized():
             return wrapped_model
+        find_unused = self.params.finetune_mode == FinetuneMode.TEXT_TO_SPEECH
         return nn.parallel.DistributedDataParallel(
             wrapped_model,
             device_ids=[dist_utils.get_local_rank()],
-            find_unused_parameters=True,
+            find_unused_parameters=find_unused,
         )
 
     def _update_eval_stats(self, eval_loss: float) -> None:
@@ -314,7 +334,7 @@ class UnitYFinetune:
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
-    def _train_step_log(self):
+    def _train_step_log(self) -> None:
         """Log train stats"""
         if (self.update_idx + 1) % self.params.log_steps == 0:
             avg_loss = self.train_loss_hist.reduce()
@@ -332,6 +352,9 @@ class UnitYFinetune:
         self.optimizer.zero_grad()
         tokens, units = self.model(batch)
         loss = self.calc_loss(batch, tokens, units)
+        if loss.isnan().any().item():
+            logger.error(batch.speech_to_text)
+            raise RuntimeError("Loss is Nan. Terminating.")
         self.grad_scaler.scale(loss).backward()
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
@@ -340,7 +363,7 @@ class UnitYFinetune:
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
 
-    def _save_model(self):
+    def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
             state_dict = {
@@ -351,7 +374,7 @@ class UnitYFinetune:
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
-    def run(self):
+    def run(self) -> None:
         logger.info("Start finetuning")
         self._reset_stats()
         self._eval_model()
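
Side note on the `dummy_context` trick that this patch keeps in `UnitYFinetuneWrapper.forward`: a throwaway context manager that yields once and does nothing lets the same `with` statement conditionally disable gradients for a frozen branch. Below is a minimal, self-contained sketch of that pattern; the helper name `maybe_frozen_forward` and the `frozen` flag are illustrative only and not part of the patch, and on Python 3.7+ `contextlib.nullcontext()` is the more idiomatic equivalent.

```python
from contextlib import contextmanager

import torch
from torch import nn


def maybe_frozen_forward(module: nn.Module, x: torch.Tensor, frozen: bool) -> torch.Tensor:
    # Single-use context manager that yields once and does nothing, so the
    # `with` statement below is valid whether or not gradients are disabled.
    dummy_context = contextmanager(lambda: iter([None]))()
    with torch.no_grad() if frozen else dummy_context:  # type: ignore
        return module(x)


if __name__ == "__main__":
    layer = nn.Linear(4, 2)
    out = maybe_frozen_forward(layer, torch.randn(3, 4), frozen=True)
    print(out.requires_grad)  # False when frozen, True otherwise
```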