Jelajahi Sumber

Various nit fixes (#138)

Can Balioglu 1 tahun lalu
induk
commit
f9608d8cbd

+ 1 - 0
dev_requirements.txt

@@ -1,3 +1,4 @@
+audiocraft
 black
 flake8
 isort

+ 6 - 1
pyproject.toml

@@ -27,6 +27,11 @@ warn_unused_ignores = false
 minversion = "7.1"
 testpaths = ["tests"]
 filterwarnings = [
-    "ignore:torch.nn.utils.weight_norm is deprecated in favor of",
+    "ignore:Deprecated call to `pkg_resources",
+    "ignore:Please use `line_search_wolfe",
+    "ignore:Please use `spmatrix",
     "ignore:TypedStorage is deprecated",
+    "ignore:distutils Version classes are deprecated",
+    "ignore:pkg_resources is deprecated",
+    "ignore:torch.nn.utils.weight_norm is deprecated in favor of",
 ]

+ 2 - 2
src/seamless_communication/models/tokenizer.py

@@ -13,7 +13,7 @@ from fairseq2.data.text import (
     TextTokenDecoder,
     TextTokenEncoder,
     TextTokenizer,
-    vocabulary_from_sentencepiece,
+    vocab_info_from_sentencepiece,
 )
 from fairseq2.data.typing import PathLike
 from fairseq2.typing import Device, finaloverride
@@ -47,7 +47,7 @@ class SPMTokenizer(TextTokenizer):
         # Each language is represented by a `__lang__` control symbol.
         control_symbols = [self._lang_tok_to_internal(lang) for lang in sorted(langs)]
         self.model = SentencePieceModel(pathname, control_symbols)
-        vocab_info = vocabulary_from_sentencepiece(self.model)
+        vocab_info = vocab_info_from_sentencepiece(self.model)
         super().__init__(vocab_info)
 
     @classmethod

+ 2 - 2
src/seamless_communication/models/unity/char_tokenizer.py

@@ -20,7 +20,7 @@ from fairseq2.data.text import (
     TextTokenDecoder,
     TextTokenEncoder,
     TextTokenizer,
-    vocabulary_from_sentencepiece,
+    vocab_info_from_sentencepiece,
 )
 from fairseq2.data.typing import PathLike
 from fairseq2.typing import Device, finaloverride
@@ -39,7 +39,7 @@ class CharTokenizer(TextTokenizer):
         """
         self.model = SentencePieceModel(pathname)
 
-        vocab_info = vocabulary_from_sentencepiece(self.model)
+        vocab_info = vocab_info_from_sentencepiece(self.model)
 
         super().__init__(vocab_info)
 

+ 4 - 4
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -367,12 +367,12 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
                 break
 
             pred_indices.append(index)
-            if self.state_bag.step == 0:
-                self.state_bag.increment_step(
+            if self.state_bag.step_nr == 0:
+                self.state_bag.increment_step_nr(
                     len(self.prefix_indices + states.target_indices)
                 )
             else:
-                self.state_bag.increment_step()
+                self.state_bag.increment_step_nr()
 
         states.target_indices += pred_indices
 
@@ -426,7 +426,7 @@ class UnitYMMATextDecoderAgent(MMASpeechToTextDecoderAgent):
             # TODO: a temporary solution.
             ending_token_index = self.text_tokenizer.model.token_to_index(",")
             token_list.append(ending_token_index)
-            self.state_bag.increment_step()
+            self.state_bag.increment_step_nr()
 
             _, _, decoder_features = self.run_decoder(states, [ending_token_index])
             decoder_features_out = torch.cat(

+ 6 - 9
tests/integration/models/test_watermarked_vocoder.py

@@ -5,22 +5,19 @@
 # LICENSE file in the root directory of this source tree.
 
 import sys
-from typing import cast, List, Final, Optional
-from anyio import Path
+from pathlib import Path
+from typing import Final, List, Optional, cast
+
 import torch
-from fairseq2.typing import Device
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoderOutput
+from fairseq2.typing import Device
 from torch.nn import Module
 
 from seamless_communication.inference.pretssel_generator import PretsselGenerator
-from seamless_communication.models.unity.loader import load_gcmvn_stats
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
-from tests.common import (
-    assert_close,
-    convert_to_collated_fbank,
-)
-
+from seamless_communication.models.unity.loader import load_gcmvn_stats
+from tests.common import assert_close, convert_to_collated_fbank
 
 N_MEL_BINS = 80