
lid classification head training script

Ruslan Mavlyutov 1 year ago
parent
commit
869bd731ec

+ 0 - 0
src/seamless_communication/cli/m4t/classification_head/__init__.py


+ 197 - 0
src/seamless_communication/cli/m4t/classification_head/dataloader.py

@@ -0,0 +1,197 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+
+import json
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torchaudio
+from datasets import Dataset
+from datasets.distributed import split_dataset_by_node
+from fairseq2.data.text import TextTokenEncoder
+from fairseq2.models.nllb import NllbTokenizer
+from fairseq2.data.audio import WaveformToFbankConverter
+from torch import Tensor
+from torch.nn.functional import pad as pad_tensor
+from torch.utils.data import DataLoader
+
+from seamless_communication.datasets.datatypes import LangPairSample
+from seamless_communication.models.unity.unit_tokenizer import (
+    UnitTokenEncoder,
+    UnitTokenizer,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SeqsBatch:
+    src_tokens: Optional[Tensor]
+    src_lengths: Optional[Tensor]
+
+    def __del__(self) -> None:
+        """Explicitly delete tensors
+        to force GPU memory cleanup"""
+        for tensor in [self.src_tokens, self.src_lengths]:
+            if tensor is not None:
+                del tensor
+
+
+@dataclass
+class BatchingConfig:
+    fbank_feats_pad_idx: int = 0
+    """The pad index to use in fbanks batching."""
+
+    batch_size: int = 5
+    """Fixed batch size to use"""
+
+    max_audio_length_sec: float = 15.0
+    """Drop samples whose source audio is longer than this threshold (seconds)."""
+
+    rank: int = 0
+    """The rank of this worker in the process group."""
+
+    world_size: int = 1
+    """The world size of the process group."""
+
+    num_workers: int = 2
+    """Parallelism in dataset preparation."""
+
+    float_dtype: torch.dtype = torch.float16
+    """Select between fp16/fp32 for float tensors """
+
+    langs: Tuple[str, ...] = ("eng", "fra", "deu", "rus", "spa")
+    """Class labels"""
+
+
+def worker_init_fn(worker_id: int) -> None:
+    # give each dataloader worker process a distinct numpy seed
+    np.random.seed(np.random.get_state()[1][0] + worker_id)  # type: ignore
+
+
+class UnitYLanguageIDDataLoader:
+    SAMPLE_RATE = 16_000
+
+    def __init__(
+        self,
+        num_languages: int,
+        text_tokenizer: NllbTokenizer,
+        unit_tokenizer: UnitTokenizer,
+        dataset_manifest_path: str,
+        batching_config: BatchingConfig,
+    ):
+        self.num_languages = num_languages
+        self.text_tokenizer = text_tokenizer
+        self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
+        self.unit_tokenizer = unit_tokenizer
+        self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
+        self.batching_config = batching_config
+        self._fbank_extract_params = {
+            "num_mel_bins": 80,
+            "waveform_scale": 32768,
+            "channel_last": True,
+            "standardize": True,
+            "device": torch.device("cpu"),
+            "dtype": self.batching_config.float_dtype,
+        }
+        self.dataset = self._load_manifest(dataset_manifest_path)
+
+    def get_dataloader(self) -> DataLoader[SeqsBatch]:
+        subset = split_dataset_by_node(
+            self.dataset,
+            rank=self.batching_config.rank,
+            world_size=self.batching_config.world_size,
+        )
+        data_loader = DataLoader(
+            dataset=subset,
+            batch_size=self.batching_config.batch_size,
+            shuffle=True,
+            num_workers=self.batching_config.num_workers,
+            collate_fn=self._collate,
+            worker_init_fn=worker_init_fn,
+        )
+        return data_loader
+
+    def __iter__(self) -> Iterable[Any]:
+        return self.get_dataloader().__iter__()
+
+    def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
+        wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
+        assert (
+            int(sample_rate) == self.SAMPLE_RATE
+        ), f"sample != {self.SAMPLE_RATE}, please resample"
+        assert len(wav.shape) in (1, 2)
+        if len(wav.shape) == 1:
+            wav = wav.unsqueeze(-1)
+        elif wav.shape[0] <= 2:  # channel is first, should be second
+            wav = wav.transpose(0, 1)
+        return WaveformToFbankConverter(**self._fbank_extract_params)(  # type: ignore
+            {
+                "waveform": wav,
+                "sample_rate": self.SAMPLE_RATE,
+            }
+        )["fbank"]
+
+    def _batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor:
+        padding_size = max(tensor.shape[0] for tensor in tensors)
+        dims = len(tensors[0].shape)
+        padded_tensors = []
+        for tensor in tensors:
+            # F.pad consumes pad widths starting from the last dim, so the final
+            # entry right-pads dim 0 (the time axis) up to the longest tensor
+            padding = [0] * 2 * dims
+            padding[-1] = padding_size - tensor.shape[0]
+            padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
+        return torch.stack(padded_tensors, dim=0)
+
+    def _is_long_src_audio(self, sample: LangPairSample) -> bool:
+        # HACK: unreadable audio is reported as "too long" so it gets filtered out below
+        try:
+            wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
+            length_s: float = max(wav.shape) / sample_rate
+            return length_s > self.batching_config.max_audio_length_sec
+        except Exception:
+            logger.exception(
+                f"Failed to load sample path: {sample.source.audio_local_path}"
+            )
+            return True
+
+    def _collate(
+        self, raw_samples: List[Dict[str, Any]]
+    ) -> Tuple[SeqsBatch, torch.Tensor]:
+        samples = [LangPairSample.from_json(sample) for sample in raw_samples]
+
+        # Input Speech
+
+        # 1 - filter long audio samples
+        filtered_samples = [
+            sample for sample in samples if not self._is_long_src_audio(sample)
+        ]
+        samples = (
+            filtered_samples if filtered_samples else [samples[0]]
+        )  # keep at least one sample
+        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+
+        # 2 - filter NaNs in fbanks
+        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
+        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
+        assert len(samples) > 0
+        src_tokens_list = [
+            tok for tok, skip in zip(src_tokens_list, with_nans) if not skip
+        ]
+        src_tokens = self._batch_tensors(
+            src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
+        ).to(self.batching_config.float_dtype)
+        src_lengths = torch.LongTensor([tok.shape[0] for tok in src_tokens_list])
+        source_lang_ids = torch.LongTensor(
+            [self.batching_config.langs.index(sample.source.lang) for sample in samples]
+        )
+        # logger.info(f"Batch size {source_lang_ids.shape}, lengths: {src_lengths}, labels: {source_lang_ids}")
+
+        return SeqsBatch(src_tokens=src_tokens, src_lengths=src_lengths), source_lang_ids
+
+    def _load_manifest(self, dataset_manifest_path: str) -> Dataset:
+        with open(dataset_manifest_path) as fp_in:
+            dataset = [json.loads(line) for line in fp_in]
+            return Dataset.from_list(dataset)
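
For reference, a minimal sketch of driving this dataloader directly (the manifest path, batch size, and 5-class setup are assumptions; the loader functions are the ones used in train.py below):

import torch
from seamless_communication.models.unity import (
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from seamless_communication.cli.m4t.classification_head.dataloader import (
    BatchingConfig,
    UnitYLanguageIDDataLoader,
)

# "train_eng_manifest.json" is a hypothetical manifest produced by dataset.py
loader = UnitYLanguageIDDataLoader(
    num_languages=5,
    text_tokenizer=load_unity_text_tokenizer("seamlessM4T_medium"),
    unit_tokenizer=load_unity_unit_tokenizer("seamlessM4T_medium"),
    dataset_manifest_path="train_eng_manifest.json",
    batching_config=BatchingConfig(batch_size=4, float_dtype=torch.float32),
)
for seqs, labels in loader:
    # seqs.src_tokens: (batch, frames, 80) fbanks; labels index BatchingConfig.langs
    print(seqs.src_tokens.shape, labels)
    break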

+ 209 - 0
src/seamless_communication/cli/m4t/classification_head/dataset.py

@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+import dataclasses
+import json
+import logging
+import os
+from pathlib import Path
+
+import torch
+import torchaudio
+
+from datasets import load_dataset
+from seamless_communication.datasets.huggingface import (
+    SpeechTokenizer,
+)
+from seamless_communication.models.unit_extractor import UnitExtractor
+
+from seamless_communication.datasets.datatypes import LangPairSample, MultimodalSample
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger("dataset")
+
+UNITY_TO_COMMON_VOICE_LANG_MAPPING = {
+    "eng": "en",
+    "ita": "it",
+    "afr": "af",
+    "asm": "as",
+    "bel": "be",
+    "bul": "bg",
+    "ben": "bn",
+    "cat": "ca",
+    "ces": "cs",
+    "dan": "da",
+    "deu": "de",
+    "ell": "el",
+    "fin": "fi",
+    "fra": "fr",
+    "glg": "gl",
+    "heb": "he",
+    "hin": "hi",
+    "hrv": "hr",
+    "hun": "hu",
+    "ind": "id",
+    "ibo": "ig",
+    "isl": "is",
+    "jpn": "ja",
+    "jav": "jv",
+    "kaz": "kk",
+    "kan": "kn",
+    "kir": "ky",
+    "kor": "ko",
+    "lit": "lt",
+    "mkd": "mk",
+    "mlt": "mt",
+    "mya": "my",
+    "nld": "nl",
+    "pan": "pa",
+    "pol": "pl",
+    "ron": "ro",
+    "rus": "ru",
+    "snd": "sd",
+    "slk": "sk",
+    "spa": "es",
+    "srp": "sr",
+    "swh": "sw",
+    "tam": "ta",
+    "tel": "te",
+    "tha": "th",
+    "tur": "tr",
+    "ukr": "uk",
+    "urd": "ur",
+    "uzn": "uz",
+    "vie": "vi",
+    "yor": "yo",
+    "zul": "zu",
+}
+
+
+def _check_lang_code_mapping(lang: str) -> None:
+    if lang not in UNITY_TO_COMMON_VOICE_LANG_MAPPING:
+        raise ValueError(
+            f"No language code mapping for {lang}(M4T)->??(CV). "
+            "Please expand `UNITY_TO_COMMON_VOICE_LANG_MAPPING`"
+        )
+
+
+class UnitSpeechTokenizer(SpeechTokenizer):
+    MODEL_NAME = "xlsr2_1b_v2"
+    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
+    OUTPUT_LAYER_IDX = 34
+
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.device = device
+        self.unit_extractor = UnitExtractor(
+            model_name_or_card=self.MODEL_NAME,
+            kmeans_uri=self.KMEANS_MODEL_URI,
+            device=self.device,
+        )
+
+    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        return self.unit_extractor.predict(
+            wav.to(self.device),
+            out_layer_idx=self.OUTPUT_LAYER_IDX,
+            sample_rate=sample_rate,
+        )
+
+
+def download_common_voice(
+    lang: str, split: str, save_directory: str, max_samples: int
+) -> None:
+    _check_lang_code_mapping(lang)
+    mozilla_lang = UNITY_TO_COMMON_VOICE_LANG_MAPPING[lang]
+    dataset = load_dataset(
+        "mozilla-foundation/common_voice_17_0",
+        mozilla_lang,
+        split=split,
+        token=os.environ.get("HF_TOKEN"),
+        streaming=True,
+    )
+    audio_dir = os.path.join(save_directory, "audio")
+    if not os.path.exists(audio_dir):
+        os.makedirs(audio_dir)
+    manifest_path: str = os.path.join(save_directory, f"{split}_{lang}_manifest.json")
+    with open(manifest_path, "w") as fp_out:
+        for idx, sample in enumerate(dataset, start=1):
+            wav = torch.from_numpy(sample["audio"]["array"]).unsqueeze(0)
+            logger.info(f"WAV SHAPE {wav.shape}")
+            sampling_rate = sample["audio"]["sampling_rate"]
+            audio_path = (
+                split
+                + "_"
+                + os.path.basename(sample["audio"]["path"]).split(".")[0]
+                + ".wav"
+            )
+            audio_path = os.path.join(audio_dir, audio_path)
+            target_sr = 16000
+            wav = torchaudio.functional.resample(
+                wav, orig_freq=sampling_rate, new_freq=target_sr
+            )
+            torchaudio.save(audio_path, wav, target_sr)
+            sample = MultimodalSample(
+                id=idx, lang=lang, text=sample["sentence"], audio_local_path=audio_path
+            )
+            sample = LangPairSample(source=sample, target=sample)
+            fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
+            fp_out.flush()
+            if idx == max_samples:
+                break
+    logger.info(f"Saved {idx} samples for split={split} to {manifest_path}")
+    logger.info(f"Manifest saved to: {manifest_path}")
+
+
+def init_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description=(
+            "Helper script to download a training/evaluation dataset (Common Voice) "
+            "and save it as a manifest consumable by `train.py`."
+        )
+    )
+    parser.add_argument(
+        "--lang",
+        type=str,
+        required=True,
+        help="Language of the dataset",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        required=True,
+        help="Dataset split/shard to download (`train`, `validation`, `test`)",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=Path,
+        required=True,
+        help="Directory where the datasets will be stored with HuggingFace datasets cache files",
+    )
+    parser.add_argument(
+        "--max_samples",
+        type=int,
+        default=1000,
+        help="Max samples to fetch",
+    )
+    return parser
+
+
+def main() -> None:
+    args = init_parser().parse_args()
+    download_common_voice(
+        lang=args.lang,
+        split=args.split,
+        save_directory=args.save_dir,
+        max_samples=args.max_samples,
+    )
+
+
+if __name__ == "__main__":
+    main()
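
A hedged example of calling the downloader programmatically rather than through the CLI (requires an HF_TOKEN with access to the gated Common Voice dataset; the output directory and sample count are arbitrary choices):

import os

from seamless_communication.cli.m4t.classification_head.dataset import (
    download_common_voice,
)

assert "HF_TOKEN" in os.environ, "set a HuggingFace token for the gated CV dataset"
# Fetch 100 English training clips: writes resampled 16 kHz WAVs under
# ./cv_data/audio/ and a ./cv_data/train_eng_manifest.json manifest for train.py.
download_common_voice(
    lang="eng", split="train", save_directory="./cv_data", max_samples=100
)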

+ 31 - 0
src/seamless_communication/cli/m4t/classification_head/model.py

@@ -0,0 +1,31 @@
+import torch
+from torch import nn
+
+
+class ClassificationHead(nn.Module):
+    def __init__(
+        self, embed_dim: int, n_layers: int = 3, n_classes: int = 5, n_heads: int = 16
+    ):
+        super().__init__()
+        self.num_languages = n_classes
+        self.num_layers = n_layers
+
+        # batch_first so the attention consumes (batch, seq, embed) inputs,
+        # matching the layout documented in forward()
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads=n_heads, batch_first=True)
+        self.layers = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Linear(embed_dim, embed_dim),
+                    nn.BatchNorm1d(embed_dim),  # normalize batch
+                    nn.ReLU(),  # activation function
+                    nn.Dropout(0.1),  # prevent overfitting
+                )
+                for _ in range(n_layers)
+            ]
+            + [nn.Linear(embed_dim, n_classes)]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (batch, seq, embed) encoder output
+        x, _ = self.attn(x, x, x)
+        x = x[:, 0]  # pool by taking the first position
+        for layer in self.layers:
+            x = layer(x)
+        return x  # (batch, n_classes) logits
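
A quick smoke test of the head's shapes (a sketch; embed_dim=1024 is an assumption, not necessarily the seamlessM4T_medium model dim):

import torch

from seamless_communication.cli.m4t.classification_head.model import ClassificationHead

head = ClassificationHead(embed_dim=1024, n_layers=2, n_classes=5)
head.eval()  # BatchNorm1d layers use running stats in eval mode
x = torch.randn(3, 50, 1024)  # stand-in encoder output: (batch, seq, embed)
logits = head(x)
print(logits.shape)  # torch.Size([3, 5]): one score per language class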

+ 369 - 0
src/seamless_communication/cli/m4t/classification_head/train.py

@@ -0,0 +1,369 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional
+from dataclasses import dataclass
+import matplotlib.pyplot as plt
+import pickle
+
+import torch
+
+from torch.optim import AdamW
+from fairseq2.optim.lr_scheduler import MyleLR
+from fairseq2.nn.padding import PaddingMask
+
+from seamless_communication.cli.m4t.classification_head import dataloader
+from seamless_communication.models.unity import UnitYModel
+from seamless_communication.models.unity import (
+    load_unity_model,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.cli.m4t.classification_head.model import ClassificationHead
+
+logging.basicConfig(
+    level=logging.INFO,
+    format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
+)
+
+logger = logging.getLogger("train")
+
+
+@dataclass
+class ClassificationHeadTrainParams:
+    save_model_path: Path
+
+    float_dtype: torch.dtype
+
+    max_epochs: int = 10
+    """Maximum number of trainign epochs"""
+
+    warmup_steps: int = 100
+    """Number of steps with linearly increasing LR"""
+
+    learning_rate: float = 1e-5
+    """Optimizer learining rate"""
+
+    batch_size: int = 100
+    """The batch size during train steps"""
+
+    device: torch.device = torch.device("cuda")
+    """Where to run computation"""
+
+
+def init_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Script to train a language-ID classification head on top of a frozen M4T model"
+    )
+    parser.add_argument(
+        "--train_dataset",
+        type=Path,
+        required=True,
+        help="Path to manifest with train samples",
+    )
+    parser.add_argument(
+        "--eval_dataset",
+        type=Path,
+        required=True,
+        help="Path to manifest with eval samples",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="seamlessM4T_medium",
+        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
+    )
+    parser.add_argument("--num_languages", type=int, help="The number of classes")
+    parser.add_argument(
+        "--save_model_path",
+        type=Path,
+        default="/tmp/",
+        help="Path to save best finetuned model",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=2343,
+        help="Randomizer seed value",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=10,
+        help="Batch size for training",
+    )
+    parser.add_argument(
+        "--eval_batch_size",
+        type=int,
+        default=50,
+        help="Batch size for evaluation",
+    )
+    parser.add_argument(
+        "--patience",
+        type=int,
+        default=3,
+        help=(
+            "Set early termination after `patience` number of evaluations "
+            "without eval loss improvements"
+        ),
+    )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        default=10,
+        help=("Max number of training epochs"),
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help=("Finetuning learning rate"),
+    )
+    parser.add_argument(
+        "--label_smoothing",
+        type=float,
+        default=0.1,
+        help=("Label smoothing"),
+    )
+    parser.add_argument(
+        "--warmup_steps",
+        type=int,
+        default=100,
+        help=("Number of steps with linearly increasing learning rate"),
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help=("Device to fine-tune on. See `torch.device`."),
+    )
+    parser.add_argument(
+        "--num_layers",
+        type=int,
+        default=2,
+        help="The number of layers in the classification head",
+    )
+    return parser
+
+
+def plot_losslog(
+    losslog: List[float], save_to: Optional[Path] = None, yscale: str = "log"
+) -> None:
+    # TODO: Make this look good
+    plt.plot(losslog)
+    plt.yscale(yscale)
+    plt.title("Training Loss")
+    plt.xlabel("Batch")
+    plt.ylabel("Loss")
+    if save_to:
+        plt.savefig(save_to)
+        plt.clf()
+        with open(save_to.parent / "losslog.pkl", "wb") as f:
+            pickle.dump(losslog, f)
+    else:
+        plt.show()
+
+
+@torch.no_grad()
+def evaluate(
+    head: torch.nn.Module,
+    frozen_model: UnitYModel,
+    dataloader: dataloader.UnitYLanguageIDDataLoader,
+    params: ClassificationHeadTrainParams,
+) -> float:
+    head.eval()
+    frozen_model.eval()
+    losses = []
+    for batch_idx, (seqs, labels) in enumerate(dataloader.get_dataloader()):
+        assert seqs.src_tokens is not None
+        with torch.autocast(device_type=params.device.type, dtype=params.float_dtype):
+            mask = PaddingMask(seqs.src_lengths, seqs.src_tokens.size(1)).to(
+                params.device
+            )
+            vector, _ = frozen_model.encode(
+                seqs.src_tokens.to(params.device), padding_mask=mask
+            )
+            logits = head(vector)
+        # cross_entropy defaults to reduction="mean", which already averages
+        # over the batch, so no extra division by the batch size is needed
+        loss = torch.nn.functional.cross_entropy(
+            logits,
+            labels.to(params.device),
+            label_smoothing=0.1,
+        )
+        losses.append(loss.item())
+        # TODO: remove; caps evaluation at ~10 batches to keep eval steps fast
+        if batch_idx > 10:
+            break
+    return sum(losses) / len(losses)  # type: ignore
+
+
+def train(
+    head: torch.nn.Module,
+    frozen_model: UnitYModel,
+    dataloader: dataloader.UnitYLanguageIDDataLoader,
+    eval_dataloader: dataloader.UnitYLanguageIDDataLoader,
+    params: ClassificationHeadTrainParams,
+    label_smoothing: float = 0.1,
+    label_weights: Optional[torch.Tensor] = None,
+) -> torch.nn.Module:
+
+    head.train()
+    # the encoder is frozen and only used for feature extraction
+    frozen_model.eval()
+    grad_scaler = torch.cuda.amp.GradScaler()
+    optimizer = AdamW(
+        params=head.parameters(),
+        lr=params.learning_rate,
+        betas=(0.9, 0.98),
+        eps=1e-08,
+        maximize=False,
+        weight_decay=0.0,
+        fused=(params.device.type == "cuda"),
+    )
+    lr_scheduler = MyleLR(
+        optimizer=optimizer, num_warmup_steps=params.warmup_steps, start_lr=1e-9
+    )
+    loss_vals = []
+    try:
+        for epoch in range(params.max_epochs):
+            # Run batches through train step
+            for update_idx, (seqs, labels) in enumerate(dataloader.get_dataloader()):
+                assert seqs.src_tokens is not None
+                optimizer.zero_grad()
+                seqs.src_tokens = seqs.src_tokens.to(params.device)
+                labels = labels.to(params.device)
+
+                with torch.autocast(
+                    device_type=params.device.type, dtype=params.float_dtype
+                ):
+                    mask = PaddingMask(seqs.src_lengths, seqs.src_tokens.size(1)).to(
+                        params.device
+                    )
+                    vector, _ = frozen_model.encode(seqs.src_tokens, padding_mask=mask)
+                    logits = head(vector)
+
+                # use the configured label smoothing and optional class weights;
+                # cross_entropy already averages over the batch
+                loss = torch.nn.functional.cross_entropy(
+                    logits, labels, weight=label_weights, label_smoothing=label_smoothing
+                )
+                if loss.isnan().any().item():
+                    logger.error(seqs)
+                    logger.error(labels)
+                    raise RuntimeError(
+                        "Train loss is NaN! Something is wrong in the model!"
+                    )
+                loss_vals.append(loss.item())
+                if update_idx % 100 == 0:
+                    eval_loss = evaluate(
+                        head=head,
+                        frozen_model=frozen_model,
+                        dataloader=eval_dataloader,
+                        params=params,
+                    )
+                    head.train()  # evaluate() switched the head to eval mode
+                    logger.info(
+                        f" .. epoch={epoch}, "
+                        f"update={update_idx}, "
+                        f"avg_train_loss={(sum(loss_vals) / len(loss_vals)):.3f}, "
+                        f"eval_loss={eval_loss:.3f}"
+                    )
+                    loss_vals = []
+
+                grad_scaler.scale(loss).backward()
+                grad_scaler.step(optimizer)
+                grad_scaler.update()
+                lr_scheduler.step()
+
+    # Catch SIGINT (^C) keyboard interrupt, and save model before terminating
+    except KeyboardInterrupt:
+        logger.info("[SIGINT] Saving optimizer state. Exiting cleanly...")
+        torch.save(
+            optimizer.state_dict(),
+            params.save_model_path.parent / "optimizer_state.pth",
+        )
+    return head
+
+
+def main() -> None:
+    args = init_parser().parse_args()
+    device = torch.device(args.device)
+    # fp16 on accelerators; CPU autocast works with bf16
+    float_dtype = (
+        torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16
+    )
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+
+    # Freeze everything in the model, only train classification head
+    model = load_unity_model(
+        args.model_name, device=torch.device("cpu"), dtype=torch.float32
+    )
+    model.train()
+    for param in model.parameters():
+        param.requires_grad = False
+
+    head = ClassificationHead(
+        embed_dim=model.model_dim,
+        n_layers=args.num_layers,
+        n_classes=args.num_languages,
+    )
+    head.train()
+
+    assert model.target_vocab_info == text_tokenizer.vocab_info
+    if model.text_encoder is not None:
+        model.text_encoder = None
+
+    # Put model on selected device
+    model = model.to(device)
+    head = head.to(device)
+
+    # Create dataloaders
+    train_dataloader = dataloader.UnitYLanguageIDDataLoader(
+        num_languages=args.num_languages,
+        text_tokenizer=text_tokenizer,
+        unit_tokenizer=unit_tokenizer,
+        batching_config=dataloader.BatchingConfig(
+            batch_size=args.batch_size,
+            max_audio_length_sec=15.0,
+            float_dtype=float_dtype,
+        ),
+        dataset_manifest_path=args.train_dataset,
+    )
+    eval_dataloader = dataloader.UnitYLanguageIDDataLoader(
+        num_languages=args.num_languages,
+        text_tokenizer=text_tokenizer,
+        unit_tokenizer=unit_tokenizer,
+        batching_config=dataloader.BatchingConfig(
+            batch_size=args.eval_batch_size,
+            max_audio_length_sec=100.0,
+            float_dtype=float_dtype,
+        ),
+        dataset_manifest_path=args.eval_dataset,
+    )
+
+    trained_head = train(
+        head=head,
+        frozen_model=model,
+        dataloader=train_dataloader,
+        eval_dataloader=eval_dataloader,
+        label_smoothing=args.label_smoothing,
+        params=ClassificationHeadTrainParams(
+            save_model_path=Path(args.save_model_path),
+            float_dtype=float_dtype,
+            max_epochs=args.max_epochs,
+            warmup_steps=args.warmup_steps,
+            learning_rate=args.learning_rate,
+            batch_size=args.batch_size,
+            device=device,
+        ),
+    )
+
+    torch.save(trained_head.state_dict(), args.save_model_path)
+
+
+if __name__ == "__main__":
+    main()
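
And a sketch of loading the trained head for inference afterwards (assumes the head was saved to a hypothetical head.pth via --save_model_path and trained with the default 5 classes on seamlessM4T_medium):

import torch

from seamless_communication.models.unity import load_unity_model
from seamless_communication.cli.m4t.classification_head.model import ClassificationHead

model = load_unity_model("seamlessM4T_medium", device=torch.device("cpu"), dtype=torch.float32)
model.eval()
head = ClassificationHead(embed_dim=model.model_dim, n_layers=2, n_classes=5)
head.load_state_dict(torch.load("head.pth"))  # hypothetical --save_model_path value
head.eval()

fbanks = torch.randn(1, 200, 80)  # stand-in for fbanks from the dataloader above
with torch.no_grad():
    # padding_mask=None since this batch holds a single, full-length sequence
    vector, _ = model.encode(fbanks, padding_mask=None)
    pred = head(vector).argmax(dim=-1)  # index into BatchingConfig.langs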