@@ -162,8 +162,8 @@ class UnitYDataLoader:
             padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
         return torch.stack([tensor for tensor in padded_tensors], dim=0)
 
-    def _prepare_batch(self, samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
-        samples = [LangPairSample.from_json(sample) for sample in samples]
+    def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
+        samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
         src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
         src_tokens = self._batch_tensors(