@@ -86,9 +86,8 @@ class UnitYFinetuneWrapper(nn.Module):
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert self.model.t2u_model is not None
-
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
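The `dummy_context` line above builds a throwaway no-op context manager so that `torch.no_grad()` is entered only when the S2T stack is frozen. A minimal standalone sketch of the same idiom, with an illustrative `freeze` flag and module that are not part of this diff; since Python 3.7, `contextlib.nullcontext` is the stdlib equivalent of the hand-rolled `dummy_context`:

```python
# Sketch of the conditional-freeze idiom from the hunk above; the
# function and flag names are illustrative, not part of this diff.
from contextlib import nullcontext

import torch


def encode_maybe_frozen(
    module: torch.nn.Module, x: torch.Tensor, freeze: bool
) -> torch.Tensor:
    # Enter no_grad() only when the submodule is frozen, so its weights
    # receive no gradients; otherwise fall through a no-op context.
    ctx = torch.no_grad() if freeze else nullcontext()
    with ctx:
        return module(x)
```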
@@ -96,6 +95,7 @@ class UnitYFinetuneWrapper(nn.Module):
                 seqs=batch.speech_to_text.src_tokens.to(self.device),
                 seq_lens=batch.speech_to_text.src_lengths.to(self.device),
             )
+            assert batch.speech_to_text.prev_output_tokens is not None
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
                 seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
                 seq_lens=batch.speech_to_text.target_lengths.to(self.device),
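The new `assert` is both a runtime guard and a type-narrowing hint: it tells mypy that `prev_output_tokens` is no longer `Optional`, so the `.to(self.device)` call on the next line type-checks. A self-contained sketch of that narrowing pattern (names are illustrative):

```python
# Sketch of Optional narrowing via assert, as used in the hunk above.
from typing import Optional

import torch


def to_device(tokens: Optional[torch.Tensor], device: str) -> torch.Tensor:
    # Without this assert, mypy rejects `.to(...)`: tokens may be None.
    assert tokens is not None
    return tokens.to(device)
```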
@@ -103,7 +103,8 @@ class UnitYFinetuneWrapper(nn.Module):
                 encoder_padding_mask=speech_encoder_padding_mask,
             )
             text_logits = self.model.final_proj(text_decoder_out)
-
+        if batch.text_to_units.prev_output_tokens is None:
+            return (text_logits, None)
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
             (
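This early return is the behavioral core of the change: batches with no unit-side `prev_output_tokens` (i.e. text-only fine-tuning data) skip the T2U decoder entirely, and the second element of the returned tuple becomes `None`. A self-contained sketch of the new return contract, with an illustrative stand-in for the real T2U decoder:

```python
# Sketch of the Optional return contract introduced above; the T2U
# decoder is replaced by a random-tensor stand-in.
from typing import Optional, Tuple

import torch


def forward_sketch(
    text_logits: torch.Tensor, unit_targets: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if unit_targets is None:
        # Text-only batch: no T2U pass, no unit logits.
        return (text_logits, None)
    unit_logits = torch.randn(unit_targets.shape[0], 4)  # stand-in for T2U
    return (text_logits, unit_logits)


logits, units = forward_sketch(torch.randn(2, 3), None)
assert units is None  # callers must handle the text-only case
```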
@@ -141,8 +142,9 @@ class CalcLoss:
         self,
         batch: dataloader.MultimodalSeqsBatch,
         text_logits: torch.Tensor,
-        unit_logits: torch.Tensor,
+        unit_logits: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        assert batch.speech_to_text.target_lengths is not None
         s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
             text_logits.device
         )
@@ -153,6 +155,9 @@ class CalcLoss:
             ignore_prefix_size=1,
             label_smoothing=self.label_smoothing,
         )
+        if unit_logits is None:
+            return s2t_loss / s2t_numel
+        assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
             logits=unit_logits, pad_idx=self.t2u_pad_idx
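With these hunks, `CalcLoss` degrades gracefully on text-only batches: when `unit_logits` is `None` it returns just the token-normalized S2T loss, and otherwise it goes on to compute the unit loss as well. A reduced sketch on plain tensors follows; the sum-of-normalized-losses in the final line is an assumption about the code after this hunk, which the diff does not show:

```python
# Reduced sketch of the Optional-aware loss combination. The final
# combination line is assumed, since the diff ends before that step.
from typing import Optional

import torch


def combine_losses(
    s2t_loss: torch.Tensor,
    s2t_numel: torch.Tensor,
    s2u_loss: Optional[torch.Tensor],
    s2u_numel: Optional[torch.Tensor],
) -> torch.Tensor:
    if s2u_loss is None or s2u_numel is None:
        # Text-only batch: mirror the early return added above.
        return s2t_loss / s2t_numel
    # Assumed: each loss normalized by its own target token count, then summed.
    return s2t_loss / s2t_numel + s2u_loss / s2u_numel
```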