offline_w2v_bert_encoder.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. from argparse import ArgumentParser, Namespace
  8. from typing import Any, Dict
  9. import torch
  10. from fairseq2.data import SequenceData
  11. from fairseq2.data.data_pipeline import Collater
  12. from fairseq2.data.text import TextTokenizer
  13. from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
  14. from fairseq2.nn.padding import get_seqs_and_padding_mask
  15. from seamless_communication.models.unity.model import UnitYModel
  16. from simuleval.agents import AgentStates, SpeechToSpeechAgent
  17. from simuleval.agents.actions import Action, ReadAction, WriteAction
  18. from simuleval.data.segments import SpeechSegment
  19. class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent):
  20. """
  21. Incremental encoding of an wav2vec encoder output
  22. It update the whole encoder states every time when there is a new incoming segment.
  23. """
  24. def __init__(
  25. self,
  26. unity_model: UnitYModel,
  27. w2v2_encoder_config: Wav2Vec2EncoderConfig,
  28. text_tokenizer: TextTokenizer,
  29. args: Namespace,
  30. ) -> None:
  31. super().__init__(args)
  32. self.model = unity_model
  33. self.w2v2_encoder_config = w2v2_encoder_config
  34. self.collate = Collater(
  35. pad_value=text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
  36. )
  37. self.device = args.device
  38. self.dtype = args.dtype
  39. self.min_starting_wait = args.min_starting_wait_w2vbert
  40. @property
  41. def min_input_length(self) -> int:
  42. return self.w2v2_encoder_config.fbank_stride
  43. @staticmethod
  44. def add_args(parser: ArgumentParser) -> None:
  45. parser.add_argument(
  46. "--min-starting-wait-w2vbert",
  47. default=None,
  48. type=int,
  49. help="Min starting wait in w2vbert",
  50. )
  51. @torch.inference_mode()
  52. def policy(self, states: AgentStates) -> Action:
  53. """
  54. The policy for encoder is always write
  55. only if the input is too short
  56. """
  57. if len(states.source) < self.min_input_length or (
  58. self.min_starting_wait is not None
  59. and len(states.source) < self.min_starting_wait
  60. ):
  61. if states.source_finished:
  62. return WriteAction({}, finished=states.source_finished)
  63. else:
  64. return ReadAction()
  65. inputs = torch.stack(states.source).to(device=self.device, dtype=self.dtype)
  66. src: SequenceData = self.collate(inputs)
  67. seqs, padding_mask = get_seqs_and_padding_mask(src)
  68. encoder_output, _ = self.model.encode_speech(
  69. seqs,
  70. padding_mask,
  71. )
  72. return WriteAction(
  73. SpeechSegment(
  74. content=encoder_output,
  75. tgt_lang=states.tgt_lang,
  76. finished=states.source_finished,
  77. ),
  78. finished=states.source_finished,
  79. )
  80. @classmethod
  81. def from_args(
  82. cls, args: Namespace, **kwargs: Dict[str, Any]
  83. ) -> OfflineWav2VecBertEncoderAgent:
  84. unity_model = kwargs.get("unity_model", None)
  85. assert isinstance(unity_model, UnitYModel)
  86. unity_config = kwargs.get("unity_config", None)
  87. assert unity_config is not None
  88. text_tokenizer = kwargs.get("text_tokenizer", None)
  89. assert isinstance(text_tokenizer, TextTokenizer)
  90. return cls(unity_model, unity_config.w2v2_encoder_config, text_tokenizer, args)