
Make unit_extractor configurable by dtype. (#128)

Kaushik Ram Sadagopan, 1 year ago
commit 5a2d61655f

+ 12 - 20  src/seamless_communication/inference/translator.py

@@ -7,7 +7,7 @@ import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union, cast
+from typing import List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -39,6 +39,7 @@ from seamless_communication.models.unity import (
 )
 from seamless_communication.models.vocoder import load_vocoder_model
 from seamless_communication.toxicity import (
+    BadWordChecker,
     load_bad_word_checker,
 )
 from seamless_communication.toxicity.mintox import mintox_pipeline
@@ -110,9 +111,9 @@ class Translator(nn.Module):
         # Load the model.
         if device == torch.device("cpu"):
             dtype = torch.float32
-        self.model = self.load_model_for_inference(
-            load_unity_model, model_name_or_card, device, dtype
-        )
+
+        self.model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
+        self.model.eval()
         assert isinstance(self.model, UnitYModel)
 
         if text_tokenizer is None:
@@ -126,10 +127,9 @@ class Translator(nn.Module):
         if self.model.t2u_model is not None:
             self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
 
+        self.bad_word_checker: Optional[BadWordChecker] = None
         if apply_mintox:
             self.bad_word_checker = load_bad_word_checker("mintox")
-        else:
-            self.bad_word_checker = None
 
         self.apply_mintox = apply_mintox
 
@@ -150,20 +150,10 @@ class Translator(nn.Module):
         if vocoder_name_or_card is not None and (
             output_modality is None or output_modality == Modality.SPEECH
         ):
-            self.vocoder = self.load_model_for_inference(
-                load_vocoder_model, vocoder_name_or_card, device, torch.float32
+            self.vocoder = load_vocoder_model(
+                vocoder_name_or_card, device=device, dtype=torch.float32
             )
-
-    @staticmethod
-    def load_model_for_inference(
-        load_model_fn: Callable[..., nn.Module],
-        model_name_or_card: Union[str, AssetCard],
-        device: Device,
-        dtype: DataType,
-    ) -> nn.Module:
-        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
-        model.eval()
-        return model
+            self.vocoder.eval()
 
     @classmethod
     def get_prediction(
@@ -272,7 +262,9 @@ class Translator(nn.Module):
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
         if self.apply_mintox and src_lang is None:
-            raise ValueError("`src_lang` must be specified when `apply_mintox` is `True`.")
+            raise ValueError(
+                "`src_lang` must be specified when `apply_mintox` is `True`."
+            )
 
         if isinstance(input, dict):
             src = cast(SequenceData, input)
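
The removed load_model_for_inference helper only paired a loader call with a switch to eval(), so call sites now inline those two lines. A minimal sketch of the new pattern; the seamlessM4T_v2_large card name is an assumption (any valid asset card works):

    import torch

    from seamless_communication.models.unity import load_unity_model

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Translator.__init__ forces float32 on CPU, so mirror that check here.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    model = load_unity_model("seamlessM4T_v2_large", device=device, dtype=dtype)
    model.eval()  # disable dropout etc. for inference, as the old helper did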

+ 3 - 4  src/seamless_communication/models/unit_extractor/kmeans.py

@@ -7,19 +7,18 @@
 import numpy as np
 import torch
 from fairseq2.assets import download_manager
-from fairseq2.typing import Device
+from fairseq2.typing import DataType, Device
 from torch import Tensor, nn
 
 
 class KmeansModel(nn.Module):
-    def __init__(self, kmeans_uri: str, device: Device):
+    def __init__(self, kmeans_uri: str, device: Device, dtype: DataType):
         super().__init__()
         km_path = download_manager.download_checkpoint(kmeans_uri, kmeans_uri)
         km_model = np.load(km_path)
         centroids_numpy = km_model.transpose()
         centroids = torch.from_numpy(centroids_numpy)
-
-        self.centroids = centroids.to(device)
+        self.centroids = centroids.to(device=device, dtype=dtype)
         self.centroid_norm = (self.centroids**2).sum(0, keepdims=True)
 
     def forward(self, x: Tensor) -> Tensor:
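
The hunk ends before the body of forward, but the precomputed centroid_norm points at the usual expansion of the squared distance, ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, where the per-frame ||x||^2 term is constant under the argmin and can be dropped. A sketch of the assignment step this most likely computes, under that assumption:

    def forward(self, x: Tensor) -> Tensor:
        # x: (num_frames, feature_dim); self.centroids: (feature_dim, num_clusters)
        # Only the cross term and the precomputed centroid norms affect the argmin.
        distances = -2 * x.matmul(self.centroids) + self.centroid_norm
        return distances.argmin(dim=-1)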

+ 10 - 9  src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -21,7 +21,6 @@ from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor, nn
 
-from seamless_communication.inference import Translator
 from seamless_communication.models.unit_extractor.kmeans import KmeansModel
 from seamless_communication.models.unit_extractor.wav2vec2_layer_output import (
     Wav2Vec2LayerOutputModel,
@@ -47,15 +46,18 @@ class UnitExtractor(nn.Module):
         dtype: DataType = torch.float32,
     ):
         super().__init__()
-        wav2vec2_model = Translator.load_model_for_inference(
-            load_wav2vec2_model, model_name_or_card, device, dtype
+
+        wav2vec2_model = load_wav2vec2_model(
+            model_name_or_card, device=device, dtype=dtype
         )
+        wav2vec2_model.eval()
         assert isinstance(wav2vec2_model, Wav2Vec2Model)
         self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
-        self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.collate = Collater(pad_value=2, pad_to_multiple=2)
-        self.kmeans_model = KmeansModel(kmeans_uri, device)
+        self.kmeans_model = KmeansModel(kmeans_uri, device, dtype)
+        self.device = device
+        self.dtype = dtype
 
     @torch.inference_mode()
     def predict(
@@ -79,7 +81,7 @@ class UnitExtractor(nn.Module):
                 audio = audio.transpose(0, 1)
 
             decoded_audio = {
-                "waveform": audio,
+                "waveform": audio.to(dtype=self.dtype),
                 "sample_rate": sample_rate,
                 "format": -1,
             }
@@ -104,9 +106,8 @@ class UnitExtractor(nn.Module):
 
         reduced_units = reduce_list(units.cpu().tolist())
 
-        vocoder = Translator.load_model_for_inference(
-            load_vocoder_model, vocoder_name, device, torch.float32
-        )
+        vocoder = load_vocoder_model(vocoder_name, device=device, dtype=torch.float32)
+        vocoder.eval()
         assert isinstance(vocoder, Vocoder)
         wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
         return wav  # type: ignore[no-any-return]
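
With dtype now threaded through the wav2vec2 encoder, the k-means centroids, and the input waveform, the extractor can run end-to-end in half precision. A usage sketch; the layer index 34 and the kmeans_10k.npy URI follow the integration test below, while the random waveform is a stand-in input:

    import torch

    from seamless_communication.models.unit_extractor import UnitExtractor

    device = torch.device("cuda:0")
    unit_extractor = UnitExtractor(
        "xlsr2_1b_v2",
        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
        device=device,
        dtype=torch.float16,  # now reaches the k-means centroids and the waveform too
    )

    waveform = torch.randn(16000 * 4, device=device)  # 4 s of 16 kHz audio
    units = unit_extractor.predict(waveform, 34)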

+ 10 - 7  src/seamless_communication/toxicity/mintox.py

@@ -4,19 +4,22 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 from torch import Tensor
 import torch
 from torch.nn import functional as F
 
 
-from seamless_communication.inference.generator import SequenceToUnitOutput, SequenceGeneratorOptions
+from seamless_communication.inference.generator import (
+    SequenceToUnitOutput,
+    SequenceGeneratorOptions,
+)
 from seamless_communication.toxicity.bad_word_checker import (
     BadWordChecker,
 )
 from fairseq2.generation import SequenceToTextOutput, BannedSequenceProcessor
-from fairseq2.data.text.text_tokenizer import TextTokenizer, TextTokenEncoder
+from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.typing import Device
 from fairseq2.data import SequenceData
@@ -32,7 +35,7 @@ def _extract_bad_words_with_batch_indices(
     target_texts: List[StringLike],
     source_lang: str,
     target_lang: str,
-    bad_word_checker: BadWordChecker
+    bad_word_checker: BadWordChecker,
 ) -> Tuple[List[str], List[int]]:
     all_bad_words, batch_indices = [], []
 
@@ -139,9 +142,9 @@ def mintox_pipeline(
     text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
         beam_size=5, soft_max_seq_len=(1, 200)
     ),
-    unit_generation_opts: Optional[
-        SequenceGeneratorOptions
-    ] = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50)),
+    unit_generation_opts: Optional[SequenceGeneratorOptions] = SequenceGeneratorOptions(
+        beam_size=5, soft_max_seq_len=(25, 50)
+    ),
     bad_word_checker: BadWordChecker = None,
 ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
     """MinTox: Mitigation at INference time of added TOXicity."""

+ 7 - 5  tests/integration/models/test_unit_extractor.py

@@ -26,11 +26,6 @@ def test_unit_extractor() -> None:
     dtype = get_default_dtype()
 
     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
-    unit_extractor = UnitExtractor(
-        "xlsr2_1b_v2",
-        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
-        device=device,
-    )
 
     # Generate english speech for the english text.
     _, speech_output = translator.predict(
@@ -41,6 +36,13 @@ def test_unit_extractor() -> None:
     )
     assert speech_output is not None
 
+    unit_extractor = UnitExtractor(
+        "xlsr2_1b_v2",
+        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+        device=device,
+        dtype=torch.float32,
+    )
+
     units = unit_extractor.predict(speech_output.audio_wavs[0][0], 34)
 
     assert_equal(units, tensor(REF_ENG_UNITS, device=device, dtype=torch.int64))
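
Because the vocoder is still loaded in float32 regardless of the model dtype, a half-precision translator and a float32 unit extractor can coexist on one device. A sketch of that pairing; the seamlessM4T_v2_large card name and the "t2st" task string are assumptions, while the argument order follows the test above:

    import torch

    from seamless_communication.inference import Translator

    device = torch.device("cuda:0")
    # fp16 translation model; the vocoder inside Translator stays fp32.
    translator = Translator("seamlessM4T_v2_large", "vocoder_v2", device, dtype=torch.float16)
    text, speech = translator.predict(
        "Hello, world!", "t2st", tgt_lang="spa", src_lang="eng"
    )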