@@ -6,12 +6,13 @@
 import logging
+import time
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -40,6 +41,9 @@ class FinetuneMode(Enum):
 @dataclass
 class FinetuneParams:
+    model_name: str
+    """Name of the model being finetuned."""
+
     save_model_path: Path
     """Path were to save finetuned model."""
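For illustration, a reduced stand-in showing how the new field is used; `FinetuneParamsSketch` and both values are hypothetical placeholders, not the real `FinetuneParams` (which has more required fields):

```python
from dataclasses import dataclass
from pathlib import Path

@dataclass
class FinetuneParamsSketch:
    model_name: str
    """Name of the model being finetuned."""

    save_model_path: Path
    """Path where to save the finetuned model."""

params = FinetuneParamsSketch(
    model_name="example_model",             # placeholder value
    save_model_path=Path("checkpoint.pt"),  # placeholder value
)
print(params.model_name, params.save_model_path)
```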
@@ -245,6 +249,7 @@ class UnitYFinetune:
         params: FinetuneParams,
         train_data_loader: dataloader.UnitYDataLoader,
         eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
+        freeze_modules: Optional[List[str]] = None,
     ):
         self.params = params
         self.calc_loss = CalcLoss(
@@ -254,9 +259,15 @@ class UnitYFinetune:
             if model.t2u_model is not None
             else None,
         )
+
         self.model = self._wrap_model_for_trainining(model=model)
+        if freeze_modules:
+            self._freeze_modules(freeze_modules)
+
         self.train_data_loader = train_data_loader
         self.eval_data_loader = eval_data_loader
+
+        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.optimizer = AdamW(
             params=self.model.parameters(),
             lr=self.params.learning_rate,
@@ -266,7 +277,6 @@ class UnitYFinetune:
             weight_decay=0.0,
             fused=(self.params.device.type == "cuda"),
         )
-        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.lr_scheduler = MyleLR(
             optimizer=self.optimizer,
             num_warmup_steps=self.params.warmup_steps,
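The `GradScaler` now lives next to the optimizer instead of after the scheduler; functionally either spot works, since the scaler only interacts with the optimizer inside `_train_step`. A minimal sketch of the scale/backward/step/update pattern used below, with a toy linear model and random data standing in for the real trainer (CUDA assumed):

```python
import torch
from torch.optim import AdamW

# Toy stand-ins for the model and data used by the trainer.
model = torch.nn.Linear(16, 4).cuda()
optimizer = AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()  # one scaler for the whole run

for _ in range(10):
    x = torch.randn(8, 16, device="cuda")
    optimizer.zero_grad()
    # The forward pass runs in reduced precision under autocast...
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).square().mean()
    # ...while the loss is scaled up before backward so small fp16
    # gradients do not underflow; step/update then unscale and adapt.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```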
@@ -301,6 +311,15 @@ class UnitYFinetune:
                 device_ids=[dist_utils.get_local_rank()],
                 find_unused_parameters=find_unused,
             )
+
+    def _freeze_modules(self, frozen_modules: List[str]) -> None:
+        """Freeze every submodule whose qualified name starts with one of the given prefixes."""
+        for prefix in frozen_modules:
+            for name, module in self.model.named_modules():
+                if name.startswith(prefix):
+                    logger.info(f"Freezing module: {name}")
+                    for param in module.parameters():
+                        param.requires_grad = False
 
     def _update_eval_stats(self, eval_loss: float) -> None:
         self.is_best_state = (
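The prefix match over `named_modules()` freezes any submodule whose qualified name starts with one of the given strings. A self-contained sketch of the same logic on a toy model (the module names here are made up; real prefixes depend on the UnitY model structure):

```python
import torch

model = torch.nn.ModuleDict({
    "encoder": torch.nn.Linear(8, 8),
    "decoder": torch.nn.Linear(8, 8),
})

# Same logic as _freeze_modules: match on name prefix, then disable grads.
for prefix in ["encoder"]:
    for name, module in model.named_modules():
        if name.startswith(prefix):
            for param in module.parameters():
                param.requires_grad = False

print([n for n, p in model.named_parameters() if p.requires_grad])
# -> ['decoder.weight', 'decoder.bias']; frozen parameters keep grad=None,
#    so AdamW simply skips them even though they were passed to it.
```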
@@ -317,25 +336,26 @@ class UnitYFinetune:
             f"patience_steps_left={self.patience_left}"
         )
 
-    def _eval_model(self) -> None:
+    @torch.no_grad()
+    def _eval_model(self, n_batches: int) -> None:
         """Calc avg loss on eval dataset and update evaluation stats"""
         if self.eval_data_loader is None:
             return
-        logger.info("Run evaluation")
+        logger.info(f"Evaluation Step {self.update_idx // self.params.eval_steps}...")
         loss_hist = LossCollector(device=self.params.device)
         self.model.eval()
-        with torch.no_grad():
-            for batch in tqdm(self.eval_data_loader.get_dataloader()):
-                assert batch.speech_to_text.src_tokens is not None
-                with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
-                    loss = self.calc_loss(batch, *self.model(batch))
-                if loss.isnan():
-                    logger.warning("Eval loss value is NaN, setting to inf")
-                    loss_val = float("Inf")
-                else:
-                    loss_val = loss.item()
-                del batch  # force memory release
-                loss_hist.update(1, loss_val)
+        for batch in self.eval_data_loader.get_dataloader():
+            if n_batches == 0:
+                break
+            assert batch.speech_to_text.src_tokens is not None
+            with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
+                loss = self.calc_loss(batch, *self.model(batch))
+            del batch  # force memory release
+            if loss.isnan():
+                logger.warning("Eval batch loss value is NaN, skipping")
+                continue
+            loss_hist.update(1, loss.item())
+            n_batches -= 1
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
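The new `n_batches` argument caps how much of the eval set each evaluation consumes. The counting pattern above is equivalent to slicing the loader; a standalone sketch with a hypothetical loss function and loader, using `itertools.islice`:

```python
import itertools
import torch

@torch.no_grad()
def eval_mean_loss(loader, n_batches: int) -> float:
    # Pull at most n_batches from the (possibly very long) loader.
    losses = [batch.square().mean().item()
              for batch in itertools.islice(loader, n_batches)]
    return sum(losses) / max(len(losses), 1)

# Hypothetical loader: a generator of random "batches".
loader = (torch.randn(4, 8) for _ in range(1000))
print(eval_mean_loss(loader, n_batches=100))
```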
@@ -351,53 +371,70 @@ class UnitYFinetune:
             f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
         )
 
     def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
         with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
             tokens, units = self.model(batch)
+
         loss = self.calc_loss(batch, tokens, units)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)
-            raise RuntimeError("Loss is Nan. Terminating.")
+            raise RuntimeError("Train loss is NaN! Something is wrong in the model!")
+
         self.grad_scaler.scale(loss).backward()
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
         self.lr_scheduler.step()
+
         assert batch.speech_to_text.src_tokens is not None
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
+        self.update_idx += 1
 
     def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
-            state_dict = {
-                key.replace("module.model.", ""): value
-                for key, value in self.model.state_dict().items()
-            }
-            torch.save(state_dict, self.params.save_model_path)
+            torch.save({
+                "model_name": self.params.model_name,
+                "model": {
+                    key.replace("module.model.model.", ""): value
+                    for key, value in self.model.state_dict().items()
+                },
+            }, self.params.save_model_path)
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
     def run(self) -> None:
-        logger.info("Start finetuning")
+        logger.info("Start Finetuning")
         self._reset_stats()
-        self._eval_model()
-        batch_itr = self.train_data_loader.get_dataloader()
+        self._eval_model(n_batches=100)
+
+        train_dataloader = self.train_data_loader.get_dataloader()
+
         while self.epoch_idx < self.params.max_epochs and self.patience_left:
-            for train_batch in batch_itr:
-                self._train_step(batch=train_batch)
-                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
-                    self._eval_model()
-                    if self.is_best_state:
-                        self._save_model()
-                    elif not self.patience_left:
-                        no_improve_steps = self.params.eval_steps * self.params.patience
-                        logger.info(
-                            "Early termination, as eval loss did not improve "
-                            f"over last {no_improve_steps} updates"
-                        )
-                        break
-                self.update_idx += 1
-            self.epoch_idx += 1
+            for train_batch in tqdm(train_dataloader, desc="Training Steps"):
+                # Run the batch through one train step.
+                self._train_step(train_batch)
+
+                # Perform eval if it's time to eval.
+                if not self.update_idx or self.update_idx % self.params.eval_steps != 0:
+                    continue
+
+                # Clear GPU memory for eval.
+                torch.cuda.empty_cache()
+                self._eval_model(n_batches=100)
+
+                # Save the current model if it's the best we've ever had.
+                if self.is_best_state:
+                    self._save_model()
+                elif not self.patience_left:
+                    no_improve_steps = self.params.eval_steps * self.params.patience
+                    logger.info(
+                        "Early termination, as eval loss did not improve "
+                        f"over last {no_improve_steps} updates"
+                    )
+                    break
+
+            self.epoch_idx += 1
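Since `_save_model` now writes `{"model_name": ..., "model": ...}` rather than a bare state dict (with `key.replace` stripping the wrapper prefixes from parameter names), consumers have to unwrap the checkpoint. A minimal sketch; the path and the final `load_state_dict` target are placeholders:

```python
import torch

# Placeholder for params.save_model_path.
checkpoint = torch.load("finetuned_model.pt", map_location="cpu")

model_name = checkpoint["model_name"]  # identifies the base model
state_dict = checkpoint["model"]       # the finetuned weights

# Restore into any module whose parameter names match, e.g.:
# model.load_state_dict(state_dict)
print(model_name, len(state_dict))
```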