common.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from contextlib import contextmanager
  7. from typing import Any, Generator, List, Optional, Union
  8. import torch
  9. from fairseq2.data import Collater
  10. from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
  11. from fairseq2.typing import DataType, Device
  12. from torch import Tensor
  13. # The default device that tests should use. Note that pytest can change it based
  14. # on the provided command line arguments.
  15. device = Device("cpu")
  16. def assert_close(
  17. a: Tensor,
  18. b: Union[Tensor, List[Any]],
  19. rtol: Optional[float] = None,
  20. atol: Optional[float] = None,
  21. ) -> None:
  22. """Assert that ``a`` and ``b`` are element-wise equal within a tolerance."""
  23. if not isinstance(b, Tensor):
  24. b = torch.tensor(b, device=device, dtype=a.dtype)
  25. torch.testing.assert_close(a, b, rtol=rtol, atol=atol) # type: ignore[attr-defined]
  26. def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
  27. """Assert that ``a`` and ``b`` are element-wise equal."""
  28. if not isinstance(b, Tensor):
  29. b = torch.tensor(b, device=device, dtype=a.dtype)
  30. torch.testing.assert_close(a, b, rtol=0, atol=0) # type: ignore[attr-defined]
  31. def assert_unit_close(
  32. a: Tensor,
  33. b: Union[Tensor, List[Any]],
  34. num_unit_tol: int = 1,
  35. percent_unit_tol: float = 0.0,
  36. ) -> None:
  37. """Assert two unit sequence are equal within a tolerance"""
  38. if not isinstance(b, Tensor):
  39. b = torch.tensor(b, device=device, dtype=a.dtype)
  40. assert (
  41. a.shape == b.shape
  42. ), f"Two shapes are different, one is {a.shape}, the other is {b.shape}"
  43. if percent_unit_tol > 0.0:
  44. num_unit_tol = int(percent_unit_tol * len(a))
  45. num_unit_diff = (a != b).sum()
  46. assert (
  47. num_unit_diff <= num_unit_tol
  48. ), f"The difference is beyond tolerance, {num_unit_diff} units are different, tolerance is {num_unit_tol}"
  49. def has_no_inf(a: Tensor) -> bool:
  50. """Return ``True`` if ``a`` has no positive or negative infinite element."""
  51. return not torch.any(torch.isinf(a))
  52. def has_no_nan(a: Tensor) -> bool:
  53. """Return ``True`` if ``a`` has no NaN element."""
  54. return not torch.any(torch.isnan(a))
  55. @contextmanager
  56. def tmp_rng_seed(device: Device, seed: int = 0) -> Generator[None, None, None]:
  57. """Set a temporary manual RNG seed.
  58. The RNG is reset to its original state once the block is exited.
  59. """
  60. device = Device(device)
  61. if device.type == "cuda":
  62. devices = [device]
  63. else:
  64. devices = []
  65. with torch.random.fork_rng(devices):
  66. torch.manual_seed(seed)
  67. yield
  68. def get_default_dtype() -> DataType:
  69. if device == Device("cpu"):
  70. dtype = torch.float32
  71. else:
  72. dtype = torch.float16
  73. return dtype
  74. def convert_to_collated_fbank(audio_dict: WaveformToFbankInput, dtype: DataType) -> Any:
  75. convert_to_fbank = WaveformToFbankConverter(
  76. num_mel_bins=80,
  77. waveform_scale=2**15,
  78. channel_last=True,
  79. standardize=True,
  80. device=device,
  81. dtype=dtype,
  82. )
  83. collater = Collater(pad_value=1)
  84. feat = collater(convert_to_fbank(audio_dict))["fbank"]
  85. return feat