common.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from contextlib import contextmanager
  7. from typing import Any, Generator, List, Union
  8. import torch
  9. from fairseq2.typing import Device
  10. from torch import Tensor
  # Device on which tests run by default. pytest may replace this value
  # depending on the command line arguments it was invoked with.
  device: Device = Device("cpu")
  def assert_close(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
      """Assert that ``a`` and ``b`` are element-wise equal within a tolerance.

      :param a:
          The actual tensor.
      :param b:
          The expected value. If it is not a :class:`Tensor` (e.g. a nested
          Python list), it is converted to one with the dtype and device of
          ``a`` before the comparison.

      :raises AssertionError:
          If :func:`torch.testing.assert_close` (with its default tolerances)
          reports a mismatch between ``a`` and ``b``.
      """
      if not isinstance(b, Tensor):
          # Materialize the expected value on ``a``'s own device rather than a
          # global default; ``torch.testing.assert_close`` checks device
          # equality, so a mismatch here would fail for the wrong reason.
          b = torch.tensor(b, device=a.device, dtype=a.dtype)

      torch.testing.assert_close(a, b)  # type: ignore[attr-defined]
  def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
      """Assert that ``a`` and ``b`` are element-wise equal.

      Unlike a tolerance-based check, any difference in value fails.

      :param a:
          The actual tensor.
      :param b:
          The expected value. If it is not a :class:`Tensor` (e.g. a nested
          Python list), it is converted to one with the dtype and device of
          ``a`` before the comparison.

      :raises AssertionError:
          If :func:`torch.testing.assert_close` with zero tolerance reports a
          mismatch between ``a`` and ``b``.
      """
      if not isinstance(b, Tensor):
          # Materialize the expected value on ``a``'s own device rather than a
          # global default; ``torch.testing.assert_close`` checks device
          # equality, so a mismatch here would fail for the wrong reason.
          b = torch.tensor(b, device=a.device, dtype=a.dtype)

      # ``rtol=0, atol=0`` turns the closeness check into an exact comparison.
      torch.testing.assert_close(a, b, rtol=0, atol=0)  # type: ignore[attr-defined]
  def has_no_inf(a: Tensor) -> bool:
      """Check that ``a`` contains no infinite values.

      :returns:
          ``False`` if any element of ``a`` is ``+inf`` or ``-inf``;
          ``True`` otherwise (including for an empty tensor).
      """
      inf_mask = torch.isinf(a)

      return not bool(inf_mask.any())
  def has_no_nan(a: Tensor) -> bool:
      """Check that ``a`` contains no NaN values.

      :returns:
          ``False`` if any element of ``a`` is NaN; ``True`` otherwise
          (including for an empty tensor).
      """
      nan_mask = torch.isnan(a)

      return not bool(nan_mask.any())
  @contextmanager
  def tmp_rng_seed(device: Device, seed: int = 0) -> Generator[None, None, None]:
      """Set a temporary manual RNG seed.

      Within the block, the CPU RNG is seeded with ``seed``. If ``device`` is a
      CUDA device, that device's RNG is seeded as well. All seeded RNGs are
      reset to their original state once the block is exited; RNGs of other
      devices are not touched.

      :param device:
          The device whose RNG should be seeded in addition to the CPU RNG.
      :param seed:
          The seed to use within the block.
      """
      device = Device(device)

      is_cuda = device.type == "cuda"

      # ``fork_rng`` always saves/restores the CPU RNG; CUDA RNG states are
      # only saved/restored for the devices listed here.
      with torch.random.fork_rng([device] if is_cuda else []):
          # Deliberately not ``torch.manual_seed()``: it reseeds the RNGs of
          # *all* CUDA devices, while only the devices passed to ``fork_rng``
          # are restored on exit, which would leak RNG state out of the block.
          torch.default_generator.manual_seed(seed)

          if is_cuda:
              # ``torch.cuda.manual_seed`` seeds the current device, so make
              # ``device`` current while seeding it.
              with torch.cuda.device(device):
                  torch.cuda.manual_seed(seed)

          yield