|
@@ -350,7 +350,8 @@ class UnitYFinetune:
|
|
|
"""Run one train step"""
|
|
|
self.model.train()
|
|
|
self.optimizer.zero_grad()
|
|
|
- tokens, units = self.model(batch)
|
|
|
+ with torch.autocast(device_type=self.params.device):
|
|
|
+ tokens, units = self.model(batch)
|
|
|
loss = self.calc_loss(batch, tokens, units)
|
|
|
if loss.isnan().any().item():
|
|
|
logger.error(batch.speech_to_text)
|