alignment_extractor.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import os
from typing import Any, List, Tuple, Union

import numpy
import torch
import torch.nn as nn
import torchaudio
from fairseq2.typing import DataType, Device
from fairseq2.data.typing import StringLike
from torch import Tensor

from seamless_communication.models.aligner.loader import load_unity2_alignment_model
from seamless_communication.models.unit_extractor import UnitExtractor

try:
    import matplotlib.pyplot as plt

    matplotlib_available = True
except ImportError:
    matplotlib_available = False


class AlignmentExtractor(nn.Module):
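    """Extracts text-token-to-audio alignments with a UnitY2 aligner model.

    Given an audio waveform (or pre-extracted discrete units) and its text
    transcript, returns per-token durations measured in unit frames. An
    optional UnitExtractor converts raw audio into discrete units first.
    """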

    def __init__(
        self,
        aligner_model_name_or_card: str,
        unit_extractor_model_name_or_card: Union[Any, str] = None,
        unit_extractor_output_layer: Union[Any, int] = None,
        unit_extractor_kmeans_model_uri: Union[Any, str] = None,
        device: Device = Device("cpu"),
        dtype: DataType = torch.float32,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype

        if self.dtype == torch.float16 and self.device == Device("cpu"):
            raise RuntimeError("FP16 only works on GPU, set args accordingly")

        self.alignment_model = load_unity2_alignment_model(
            aligner_model_name_or_card, device=self.device, dtype=self.dtype
        )
        self.alignment_model.eval()

        self.unit_extractor = None
        self.unit_extractor_output_layer = 0

        if unit_extractor_model_name_or_card is not None:
            self.unit_extractor = UnitExtractor(
                unit_extractor_model_name_or_card,
                unit_extractor_kmeans_model_uri,
                device=device,
                dtype=dtype,
            )
            self.unit_extractor_output_layer = unit_extractor_output_layer

    def load_audio(
        self, audio_path: str, sampling_rate: int = 16_000
    ) -> Tuple[Tensor, int]:
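        """Load audio from `audio_path`, resampling to `sampling_rate` if needed."""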
        assert os.path.exists(audio_path)
        audio, rate = torchaudio.load(audio_path)
        if rate != sampling_rate:
            audio = torchaudio.functional.resample(audio, rate, sampling_rate)
            rate = sampling_rate
        return audio, rate

    def prepare_audio(self, audio: Union[str, Tensor]) -> Tensor:
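        """Return a mono [Time] tensor on the target device, loading from a path if needed."""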
        # TODO: switch to fairseq2 data pipeline once it supports resampling
        if isinstance(audio, str):
            audio, _ = self.load_audio(audio, sampling_rate=16_000)
        if audio.ndim > 1:
            # averaging over channels
            assert audio.size(0) < audio.size(
                1
            ), "Expected [Channel, Time] shape, but Channel > Time"
            audio = audio.mean(0)
        assert (
            audio.ndim == 1
        ), "After channel averaging, audio is expected to be mono, i.e. of shape [Time]"
        audio = audio.to(self.device, self.dtype)

        return audio

    def extract_units(self, audio: Tensor) -> Tensor:
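        """Extract discrete units from an audio tensor via the configured UnitExtractor."""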
        assert isinstance(
            self.unit_extractor, UnitExtractor
        ), "Unit extractor is required to get units from audio tensor"
        units = self.unit_extractor.predict(audio, self.unit_extractor_output_layer - 1)
        return units

    @torch.inference_mode()
    def extract_alignment(
        self,
        audio: Union[str, Tensor],
        text: str,
        plot: bool = False,
        add_trailing_silence: bool = False,
    ) -> Tuple[Tensor, Tensor, List[StringLike]]:
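        """Align `text` against `audio` (a path, waveform tensor, or integer unit tensor).

        Returns per-token durations in unit frames, the tokenized text ids,
        and the corresponding text tokens.
        """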
        if isinstance(audio, Tensor) and not torch.is_floating_point(audio):
            # we got units as audio arg
            units = audio
            units = units.to(self.device)
            audio_tensor = None
        else:
            audio_tensor = self.prepare_audio(audio)
            units = self.extract_units(audio_tensor)

        tokenized_unit_ids = self.alignment_model.alignment_frontend.tokenize_unit(
            units
        ).unsqueeze(0)
        tokenized_text_ids = (
            self.alignment_model.alignment_frontend.tokenize_text(
                text, add_trailing_silence=add_trailing_silence
            )
            .to(self.device)
            .unsqueeze(0)
        )
        tokenized_text_tokens = (
            self.alignment_model.alignment_frontend.tokenize_text_to_tokens(
                text, add_trailing_silence=add_trailing_silence
            )
        )
        _, alignment_durations = self.alignment_model(
            tokenized_text_ids, tokenized_unit_ids
        )
        if plot and (audio_tensor is not None):
            self.plot_alignment(
                audio_tensor.cpu(), tokenized_text_tokens, alignment_durations.cpu()
            )

        return alignment_durations, tokenized_text_ids, tokenized_text_tokens

    def detokenize_text(self, tokenized_text_ids: Tensor) -> StringLike:
        return self.alignment_model.alignment_frontend.decode_text(tokenized_text_ids)

    def plot_alignment(
        self, audio: Tensor, text_tokens: List[StringLike], durations: Tensor
    ) -> None:
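        """Plot the waveform with token boundaries derived from `durations` (requires matplotlib)."""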
        if not matplotlib_available:
            raise RuntimeError(
                "Please `pip install matplotlib` in order to use plot alignment."
            )
        _, ax = plt.subplots(figsize=(22, 3.5))
        ax.plot(audio, color="gray", linewidth=0.3)
        durations_cumul = numpy.concatenate([numpy.array([0]), numpy.cumsum(durations)])
        alignment_ticks = durations_cumul * 320  # 320 samples per 20 ms unit frame at 16 kHz

        ax.vlines(
            alignment_ticks,
            ymax=1,
            ymin=-1,
            color="indigo",
            linestyles="dashed",
            lw=0.5,
        )

        middle_tick_positions = (
            durations_cumul[:-1] + (durations_cumul[1:] - durations_cumul[:-1]) / 2
        )
        ax.set_xticks(middle_tick_positions * 320)
        ax.set_xticklabels(text_tokens, fontsize=13)
        ax.set_xlim(0, len(audio))
        ax.set_ylim(audio.min(), audio.max())
        ax.set_yticks([])
        plt.tight_layout()
        plt.show()
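

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# model card names, output layer, kmeans URI, and audio path below are
# assumptions for the example; substitute the cards and files available in
# your environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = AlignmentExtractor(
        aligner_model_name_or_card="nar_t2u_aligner",  # assumed aligner card name
        unit_extractor_model_name_or_card="xlsr2_1b_v2",  # assumed speech encoder card name
        unit_extractor_output_layer=35,  # assumed encoder layer for unit extraction
        unit_extractor_kmeans_model_uri="kmeans_10k.npy",  # assumed local kmeans file
        device=Device("cuda") if torch.cuda.is_available() else Device("cpu"),
        dtype=torch.float32,
    )
    durations, text_ids, text_tokens = extractor.extract_alignment(
        "example.wav",  # placeholder 16 kHz audio file
        "an example transcript",
        plot=False,
        add_trailing_silence=True,
    )
    # One duration (in 20 ms unit frames) per text token.
    print(list(zip(text_tokens, durations.squeeze(0).tolist())))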