# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from fairseq2.models.nllb.tokenizer import NllbTokenizer, TextTokenEncoder
from torch import Tensor
from torch.nn.functional import pad as pad_tensor
from torch.utils.data import DataLoader

from seamless_communication.datasets.datatypes import LangPairSample
from seamless_communication.models.unity.unit_tokenizer import (
    UnitTokenEncoder,
    UnitTokenizer,
)

logger = logging.getLogger(__name__)

@dataclass
class SeqsBatch:
    src_tokens: Optional[Tensor]
    src_lengths: Optional[Tensor]
    target_tokens: Tensor
    prev_output_tokens: Tensor
    target_lengths: Tensor

    def __del__(self) -> None:
        """Explicitly delete tensors to force GPU memory cleanup"""
        for tensor in [
            self.src_tokens,
            self.src_lengths,
            self.target_tokens,
            self.prev_output_tokens,
            self.target_lengths,
        ]:
            if tensor is not None:
                del tensor

@dataclass
class MultimodalSeqsBatch:
    speech_to_text: SeqsBatch
    text_to_units: SeqsBatch

    def __del__(self) -> None:
        del self.speech_to_text
        del self.text_to_units

@dataclass
class BatchingConfig:
    fbank_feats_pad_idx: int = 0
    """The pad index to use in fbanks batching."""

    batch_size: int = 5

    rank: int = 0
    """The rank of this worker in the process group."""

    world_size: int = 1
    """The world size of the process group."""

    num_workers: int = 2
    """Parallelism in dataset preparation."""

    float_dtype: torch.dtype = torch.float16
    """Select between fp16/fp32 for float tensors."""

def worker_init_fn(worker_id: int) -> None:
    # Each DataLoader worker starts with an identical NumPy RNG state; derive a
    # per-worker seed so the workers do not all produce the same random stream.
    np.random.seed(np.random.get_state()[1][0] + worker_id)

class UnitYDataLoader:
    def __init__(
        self,
        text_tokenizer: NllbTokenizer,
        unit_tokenizer: UnitTokenizer,
        dataset_manifest_path: str,
        batching_config: BatchingConfig,
    ):
        self.text_tokenizer = text_tokenizer
        self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
        self.unit_tokenizer = unit_tokenizer
        self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
        self.batching_config = batching_config
        self.dataset = self._load_manifest(dataset_manifest_path)

    def get_dataloader(self) -> DataLoader:
        subset = split_dataset_by_node(
            self.dataset,
            rank=self.batching_config.rank,
            world_size=self.batching_config.world_size,
        )
        data_loader = DataLoader(
            dataset=subset,
            batch_size=self.batching_config.batch_size,
            shuffle=True,
            num_workers=self.batching_config.num_workers,
            collate_fn=self._prepare_batch,
            worker_init_fn=worker_init_fn,
        )
        return data_loader

    def __iter__(self) -> Iterable[MultimodalSeqsBatch]:
        return self.get_dataloader().__iter__()
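
    # Usage sketch (hypothetical `loader`, built as in the example at the end of
    # this file): the class iterates directly, yielding collated batches.
    #
    #   for batch in loader:                  # batch: MultimodalSeqsBatch
    #       s2t = batch.speech_to_text        # fbank inputs + text targets
    #       t2u = batch.text_to_units         # unit targets only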

    def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
        # torchaudio.load returns (waveform, sample_rate); only the waveform is
        # used, so the audio is assumed to match fbank's default 16 kHz rate.
        audio_input = torchaudio.load(sample.source.audio_local_path)[0]
        return ta_kaldi.fbank(audio_input, num_mel_bins=80)

    def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
        """Expected sequence is [<eos>, <lang_tok>, ..text tokens.., <eos>]"""
        target_lang = sample.target.lang
        if target_lang not in self.text_encoders_per_lang:
            self.text_encoders_per_lang[target_lang] = (
                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
            )
        tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
        eos_idx = self.text_tokenizer.vocab_info.eos_idx
        tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
        return tokens

    def _get_tokenized_units(self, sample: LangPairSample) -> Tensor:
        """Expected sequence is [<eos>, <lang_tok>, ..unit tokens.., <eos>]"""
        target_lang = sample.target.lang
        if target_lang not in self.unit_encoders_per_lang:
            self.unit_encoders_per_lang[target_lang] = (
                self.unit_tokenizer.create_encoder(lang=target_lang)
            )
        tokens = self.unit_encoders_per_lang[target_lang](
            torch.LongTensor(sample.target.units).unsqueeze(0)
        )
        eos_idx = self.unit_tokenizer.vocab_info.eos_idx
        tokens = torch.concat([tokens.squeeze(0), torch.LongTensor([eos_idx])])
        return tokens

    def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
        # Right-pad every tensor along its first dimension to the longest
        # length in the batch, then stack into a single batch tensor.
        padding_size = max(tensor.shape[0] for tensor in tensors)
        dims = len(tensors[0].shape)
        padded_tensors = []
        for tensor in tensors:
            padding = [0] * 2 * dims
            padding[-1] = padding_size - tensor.shape[0]
            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
        return torch.stack(padded_tensors, dim=0)
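
    # Worked example (hypothetical shapes): padding two fbank tensors of shapes
    # (437, 80) and (512, 80) with pad_value=0 yields two (512, 80) tensors,
    # stacked into a single (2, 512, 80) batch.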

    def _prepare_batch(self, samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
        samples = [LangPairSample.from_json(sample) for sample in samples]
        # input speech
        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
        src_tokens = self._batch_tensors(
            src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
        ).to(self.batching_config.float_dtype)
        src_lengths = torch.LongTensor([fbank.shape[0] for fbank in src_tokens_list])
        # output text
        text_tokens_list = [
            self._get_tokenized_target_text(sample) for sample in samples
        ]
        text_pad_idx = self.text_tokenizer.vocab_info.pad_idx
        prev_outputs_tokens = self._batch_tensors(
            [tokens[:-1] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        target_tokens = self._batch_tensors(
            [tokens[1:] for tokens in text_tokens_list], pad_value=text_pad_idx
        )
        tokens_lengths = torch.LongTensor(
            [tokens.shape[0] - 1 for tokens in text_tokens_list]
        )
        # output units
        units_list = [self._get_tokenized_units(sample) for sample in samples]
        units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
        prev_outputs_units = self._batch_tensors(
            [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
        )
        target_units = self._batch_tensors(
            [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
        )
        units_lengths = torch.LongTensor(
            [tokens.shape[0] - 1 for tokens in units_list]
        )
        return MultimodalSeqsBatch(
            speech_to_text=SeqsBatch(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                target_tokens=target_tokens,
                prev_output_tokens=prev_outputs_tokens,
                target_lengths=tokens_lengths,
            ),
            text_to_units=SeqsBatch(
                src_tokens=None,
                src_lengths=None,
                target_tokens=target_units,
                prev_output_tokens=prev_outputs_units,
                target_lengths=units_lengths,
            ),
        )
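
    # Shape sketch for the returned batch (hypothetical sizes: B samples, longest
    # fbank length T, longest tokenized text length L, longest unit sequence U):
    #
    #   speech_to_text.src_tokens          -> (B, T, 80)   fbank features
    #   speech_to_text.prev_output_tokens  -> (B, L - 1)   teacher-forcing inputs
    #   speech_to_text.target_tokens       -> (B, L - 1)   shifted-by-one targets
    #   text_to_units.prev_output_tokens   -> (B, U - 1)
    #   text_to_units.target_tokens        -> (B, U - 1)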

    def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
        # The manifest is JSON Lines: one JSON-encoded sample per line.
        with open(dataset_manifest_path) as fp_in:
            dataset = [json.loads(line) for line in fp_in]
        return Dataset.from_list(dataset)
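
# A minimal end-to-end sketch, assuming the tokenizer loaders exported by
# seamless_communication.models.unity and a "seamlessM4T_medium" checkpoint;
# the manifest path and configuration values are purely illustrative.
#
#   from seamless_communication.models.unity import (
#       load_unity_text_tokenizer,
#       load_unity_unit_tokenizer,
#   )
#
#   loader = UnitYDataLoader(
#       text_tokenizer=load_unity_text_tokenizer("seamlessM4T_medium"),
#       unit_tokenizer=load_unity_unit_tokenizer("seamlessM4T_medium"),
#       dataset_manifest_path="train_manifest.json",
#       batching_config=BatchingConfig(batch_size=8, rank=0, world_size=1),
#   )
#   for batch in loader:
#       ...  # feed batch.speech_to_text / batch.text_to_units to the model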