@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from enum import Enum
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -21,7 +21,7 @@ from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
-from torch.optim import AdamW
+from torch.optim import Adam, AdamW

 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import (
@@ -88,11 +88,17 @@ class UnitYFinetuneWrapper(nn.Module):
     def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
         super().__init__()
         self.model: UnitYModel = model
+        # self._freeze_module(self.model.speech_encoder_frontend)
+        # self._freeze_module(self.model.speech_encoder)
         self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
         self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
         logger.info(f"Freeze s2t: {self.freeze_s2t}, freeze t2u: {self.freeze_t2u}")
         self.device = device

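+    # Helper for the optional encoder freezing above: parameters with
+    # requires_grad disabled are skipped during gradient updates.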
+    def _freeze_module(self, module: torch.nn.Module) -> None:
+        for param in module.parameters():
+            param.requires_grad = False
+
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -329,12 +335,11 @@ class UnitYFinetune:
             assert batch.speech_to_text.src_tokens is not None
             with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
                 loss = self.calc_loss(batch, *self.model(batch))
-            if loss.isnan():
-                logger.warning("Eval loss value is NaN, setting to inf")
-                loss_val = float("Inf")
-            else:
-                loss_val = loss.item()
             del batch  # force memory release
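+            # Skip NaN batches rather than recording an infinite loss, so a
+            # single bad batch cannot dominate the reduced eval loss.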
+            if loss.isnan():
+                logger.warning("Eval batch loss is NaN, skipping batch")
+                continue
+            loss_val = loss.item()
             loss_hist.update(1, loss_val)
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
@@ -351,13 +356,18 @@ class UnitYFinetune:
             f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
         )

-    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+    def _train_step(self, batches: List[dataloader.MultimodalSeqsBatch]) -> None:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
-        with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
-            tokens, units = self.model(batch)
-            loss = self.calc_loss(batch, tokens, units)
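+        # Gradient accumulation: run a forward pass per micro-batch, collect
+        # the losses, and back-propagate once through their mean.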
+ # logger.info(f"forward start {torch.cuda.memory_allocated(0) >> 30}g")
|
|
|
|
+ losses = []
|
|
|
|
+ for batch in batches:
|
|
|
|
+ with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
|
|
|
|
+ tokens, units = self.model(batch)
|
|
|
|
+ # logger.info(f"forward done {torch.cuda.memory_allocated(0) >> 30}g")
|
|
|
|
+ losses.append(self.calc_loss(batch, tokens, units))
|
|
|
|
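+        # Averaging keeps gradient magnitudes comparable to single-batch
+        # training, independent of how many batches were accumulated.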
+        loss = sum(losses) / len(losses)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)
             raise RuntimeError("Loss is Nan. Terminating.")
@@ -365,6 +375,7 @@ class UnitYFinetune:
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
         self.lr_scheduler.step()
+        # logger.info(f"backward done {torch.cuda.memory_allocated(0) >> 30}g")
         assert batch.speech_to_text.src_tokens is not None
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
@@ -385,19 +396,24 @@ class UnitYFinetune:
         self._reset_stats()
         self._eval_model()
         batch_itr = self.train_data_loader.get_dataloader()
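+        # Micro-batches accumulated into each optimizer step; raise to trade
+        # GPU memory for a larger effective batch size.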
+        batches_per_iter = 1
         while self.epoch_idx < self.params.max_epochs and self.patience_left:
+            train_batches = []
             for train_batch in batch_itr:
-                self._train_step(batch=train_batch)
-                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
-                    self._eval_model()
-                    if self.is_best_state:
-                        self._save_model()
-                    elif not self.patience_left:
-                        no_improve_steps = self.params.eval_steps * self.params.patience
-                        logger.info(
-                            "Early termination, as eval loss did not improve "
-                            f"over last {no_improve_steps} updates"
-                        )
-                        break
-                self.update_idx += 1
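+                # Flush a full group as one accumulated train step.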
+                train_batches.append(train_batch)
+                if len(train_batches) >= batches_per_iter:
+                    self._train_step(batches=train_batches)
+                    train_batches = []
+                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
+                    self._eval_model()
+                    if self.is_best_state:
+                        self._save_model()
+                    elif not self.patience_left:
+                        no_improve_steps = self.params.eval_steps * self.params.patience
+                        logger.info(
+                            "Early termination, as eval loss did not improve "
+                            f"over last {no_improve_steps} updates"
+                        )
+                        break
+                self.update_idx += 1
         self.epoch_idx += 1