浏览代码

Fix all mypy issues in streaming, and some minor bugs in tree pipeline. (#147)

* Fix all mypy issues in streaming, and some minor bugs in tree pipeline.

* Revert changes to dev_requirements.txt, setup.py
Kaushik Ram Sadagopan 1 年之前
父节点
当前提交
e568857c64

+ 1 - 1
src/seamless_communication/cli/streaming/scorers/seamless_whisper_asr_bleu.py

@@ -30,7 +30,7 @@ def normalize_text_whisper(sentences: List[str], lang: str) -> List[str]:
 
 
 
 
 @register_quality_scorer("SEAMLESS_WHISPER_ASR_BLEU")
 @register_quality_scorer("SEAMLESS_WHISPER_ASR_BLEU")
-class SeamlessWhisperASRSacreBLEUScorer(WhisperASRSacreBLEUScorer):
+class SeamlessWhisperASRSacreBLEUScorer(WhisperASRSacreBLEUScorer):  # type: ignore
     def __init__(
     def __init__(
         self,
         self,
         tokenizer: str = "13a",
         tokenizer: str = "13a",

+ 1 - 1
src/seamless_communication/streaming/agents/common.py

@@ -21,7 +21,7 @@ class EarlyStoppingMixin:
         raise NotImplementedError()
         raise NotImplementedError()
 
 
 
 
-class AgentStates(AgentStatesOrig):
+class AgentStates(AgentStatesOrig):  # type: ignore
     def update_target(self, segment: Segment) -> None:
     def update_target(self, segment: Segment) -> None:
         """An AgentStates impl which doesn't update states.target"""
         """An AgentStates impl which doesn't update states.target"""
         self.target_finished = segment.finished
         self.target_finished = segment.finished

+ 1 - 1
src/seamless_communication/streaming/agents/detokenizer.py

@@ -16,7 +16,7 @@ from seamless_communication.streaming.agents.common import (
 )
 )
 
 
 
 
-class DetokenizerAgent(NoUpdateTargetMixin, TextToTextAgent):
+class DetokenizerAgent(NoUpdateTargetMixin, TextToTextAgent):  # type: ignore
     def __init__(self, args: Namespace):
     def __init__(self, args: Namespace):
         super().__init__(args)
         super().__init__(args)
         self.detokenize_only = args.detokenize_only
         self.detokenize_only = args.detokenize_only

+ 1 - 1
src/seamless_communication/streaming/agents/offline_w2v_bert_encoder.py

@@ -24,7 +24,7 @@ from seamless_communication.streaming.agents.common import (
 )
 )
 
 
 
 
-class OfflineWav2VecBertEncoderAgent(NoUpdateTargetMixin, SpeechToSpeechAgent):
+class OfflineWav2VecBertEncoderAgent(NoUpdateTargetMixin, SpeechToSpeechAgent):  # type: ignore
     """
     """
     Incremental encoding of an wav2vec encoder output
     Incremental encoding of an wav2vec encoder output
     It update the whole encoder states every time when there is a new incoming segment.
     It update the whole encoder states every time when there is a new incoming segment.

+ 2 - 2
src/seamless_communication/streaming/agents/online_feature_extractor.py

@@ -26,7 +26,7 @@ SAMPLE_RATE = 16000
 FEATURE_DIM = 80
 FEATURE_DIM = 80
 
 
 
 
-class FeatureStates(AgentStates):
+class FeatureStates(AgentStates):  # type: ignore
     def reset(self) -> None:
     def reset(self) -> None:
         super().reset()
         super().reset()
         self.previous_residual_samples: List[float] = []
         self.previous_residual_samples: List[float] = []
@@ -45,7 +45,7 @@ class FeatureStates(AgentStates):
             self.source.append(segment.content)
             self.source.append(segment.content)
 
 
 
 
-class OnlineFeatureExtractorAgent(SpeechToSpeechAgent):
+class OnlineFeatureExtractorAgent(SpeechToSpeechAgent):  # type: ignore
     """
     """
     Extract speech features on the fly.
     Extract speech features on the fly.
     """
     """

+ 7 - 5
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -23,7 +23,7 @@ from simuleval.data.segments import Segment, TextSegment
 from torch import Tensor
 from torch import Tensor
 
 
 
 
-class DecoderAgentStates(AgentStates):
+class DecoderAgentStates(AgentStates):  # type: ignore
     def reset(self) -> None:
     def reset(self) -> None:
         self.source_len = 0
         self.source_len = 0
         self.target_indices: List[int] = []
         self.target_indices: List[int] = []
@@ -50,7 +50,7 @@ class DecoderAgentStates(AgentStates):
             self.source_len = self.source.size(1)
             self.source_len = self.source.size(1)
 
 
 
 
-class OnlineTextDecoderAgent(GenericAgent):
+class OnlineTextDecoderAgent(GenericAgent):  # type: ignore
     """
     """
     Online text decoder
     Online text decoder
     """
     """
@@ -139,7 +139,7 @@ class OnlineTextDecoderAgent(GenericAgent):
             self.prefix_indices[-1] = tgt_lang_tag_idx
             self.prefix_indices[-1] = tgt_lang_tag_idx
 
 
 
 
-class MMATextDecoderAgent(OnlineTextDecoderAgent):
+class MMATextDecoderAgent(OnlineTextDecoderAgent):  # type: ignore
     def __init__(
     def __init__(
         self,
         self,
         model: MonotonicDecoderModel,
         model: MonotonicDecoderModel,
@@ -278,15 +278,17 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         states: DecoderAgentStates,
         states: DecoderAgentStates,
         pred_indices: List[int],
         pred_indices: List[int],
         decoder_features_out: Tensor,
         decoder_features_out: Tensor,
-        blocked_ngrams: Set[str],
+        blocked_ngrams: Optional[Set[str]],
         index: int,
         index: int,
-    ) -> bool:
+    ) -> Tuple[bool, Tensor]:
         """
         """
         This check is used to force a READ decision when n-gram repeat
         This check is used to force a READ decision when n-gram repeat
         happens before source_finished
         happens before source_finished
         """
         """
         if not self.block_ngrams or states.source_finished:
         if not self.block_ngrams or states.source_finished:
             return False, decoder_features_out
             return False, decoder_features_out
+
+        assert blocked_ngrams is not None
         all_indices = states.target_indices + pred_indices + [index]
         all_indices = states.target_indices + pred_indices + [index]
         for n in [3, 2]:  # TODO: make it configurable
         for n in [3, 2]:  # TODO: make it configurable
             if len(all_indices) >= n and states.ngram_block_count <= 4:
             if len(all_indices) >= n and states.ngram_block_count <= 4:

+ 2 - 2
src/seamless_communication/streaming/agents/online_unit_decoder.py

@@ -20,7 +20,7 @@ from simuleval.agents.states import AgentStates
 from simuleval.data.segments import Segment, TextSegment
 from simuleval.data.segments import Segment, TextSegment
 
 
 
 
-class NARUnitDecoderAgentStates(AgentStates):
+class NARUnitDecoderAgentStates(AgentStates):  # type: ignore
     def reset(self) -> None:
     def reset(self) -> None:
         self.source_token_list: List[str] = []
         self.source_token_list: List[str] = []
         self.source_indices: Optional[torch.Tensor] = None
         self.source_indices: Optional[torch.Tensor] = None
@@ -51,7 +51,7 @@ class NARUnitDecoderAgentStates(AgentStates):
         self.source = content
         self.source = content
 
 
 
 
-class NARUnitYUnitDecoderAgent(GenericAgent):
+class NARUnitYUnitDecoderAgent(GenericAgent):  # type: ignore
     """Non-autoregressive unit decoder"""
     """Non-autoregressive unit decoder"""
 
 
     source_type = "text"
     source_type = "text"

+ 1 - 1
src/seamless_communication/streaming/agents/online_vocoder.py

@@ -14,7 +14,7 @@ from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 from simuleval.data.segments import SpeechSegment
 
 
 
 
-class VocoderAgent(TextToSpeechAgent):
+class VocoderAgent(TextToSpeechAgent):  # type: ignore
     def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
     def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
         super().__init__(args)
         super().__init__(args)
         self.sample_rate = args.sample_rate
         self.sample_rate = args.sample_rate

+ 2 - 2
src/seamless_communication/streaming/agents/silero_vad.py

@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
 SPEECH_PROB_THRESHOLD = 0.6
 SPEECH_PROB_THRESHOLD = 0.6
 
 
 
 
-class SileroVADStates(EarlyStoppingMixin, AgentStates):
+class SileroVADStates(EarlyStoppingMixin, AgentStates):  # type: ignore
     def __init__(self, args: Namespace) -> None:
     def __init__(self, args: Namespace) -> None:
         self.model, utils = torch.hub.load(
         self.model, utils = torch.hub.load(
             repo_or_dir="snakers4/silero-vad",
             repo_or_dir="snakers4/silero-vad",
@@ -253,7 +253,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
                 )
                 )
 
 
 
 
-class SileroVADAgent(SpeechToSpeechAgent):
+class SileroVADAgent(SpeechToSpeechAgent):  # type: ignore
     def __init__(self, args: Namespace) -> None:
     def __init__(self, args: Namespace) -> None:
         super().__init__(args)
         super().__init__(args)
         self.chunk_size_samples = args.chunk_size_samples
         self.chunk_size_samples = args.chunk_size_samples

+ 6 - 6
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -57,7 +57,7 @@ class UnitYPipelineMixin:
 
 
     @classmethod
     @classmethod
     def add_args(cls, parser: ArgumentParser) -> None:
     def add_args(cls, parser: ArgumentParser) -> None:
-        super().add_args(parser)
+        super().add_args(parser)  # type: ignore
         parser.add_argument("--task", type=str, help="Task type")
         parser.add_argument("--task", type=str, help="Task type")
         parser.add_argument(
         parser.add_argument(
             "--unity-model-name",
             "--unity-model-name",
@@ -157,7 +157,7 @@ class UnitYPipelineMixin:
         }
         }
 
 
 
 
-class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
+class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):  # type: ignore
     pipeline: List[GenericAgent] = []
     pipeline: List[GenericAgent] = []
 
 
     def __init__(self, args: Namespace):
     def __init__(self, args: Namespace):
@@ -199,8 +199,8 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         return cls(args)
         return cls(args)
 
 
 
 
-class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):
-    pipeline = {}
+class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):  # type: ignore
+    pipeline: Any = {}
 
 
     def __init__(self, args: Namespace):
     def __init__(self, args: Namespace):
         models_and_configs = self.load_model(args)
         models_and_configs = self.load_model(args)
@@ -231,10 +231,10 @@ class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):
             # An early stop.
             # An early stop.
             # The temporary solution is to start over
             # The temporary solution is to start over
             if states is not None:
             if states is not None:
-                maybe_reset_states(states.values())
+                maybe_reset_states(states)
             else:
             else:
                 self.reset()
                 self.reset()
             for segment in output_segment:
             for segment in output_segment:
                 segment.finished = False
                 segment.finished = False
 
 
-        return output_segment
+        return output_segment  # type: ignore[no-any-return]

+ 1 - 1
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -38,7 +38,7 @@ def count_lines(filename: Path) -> int:
 
 
 
 
 @register_dataloader("fairseq2_s2tt")
 @register_dataloader("fairseq2_s2tt")
-class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader):
+class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader):  # type: ignore
     def __init__(self, data_pipeline: DataPipeline, args: Namespace) -> None:
     def __init__(self, data_pipeline: DataPipeline, args: Namespace) -> None:
         self.args = args
         self.args = args
         self.data_file: Path = Path(getattr(self.args, "data_file", ""))
         self.data_file: Path = Path(getattr(self.args, "data_file", ""))