conftest.py 888 B

123456789101112131415161718192021222324252627282930313233
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
  6. from argparse import ArgumentTypeError
  7. from typing import cast
  8. import pytest
  9. import tests.common
  10. from fairseq2.typing import Device
  11. def parse_device_arg(value: str) -> Device:
  12. try:
  13. return Device(value)
  14. except RuntimeError:
  15. raise ArgumentTypeError(f"'{value}' is not a valid device name.")
  16. def pytest_addoption(parser: pytest.Parser) -> None:
  17. # fmt: off
  18. parser.addoption(
  19. "--device", default="cpu", type=parse_device_arg,
  20. help="device on which to run tests (default: %(default)s)",
  21. )
  22. # fmt: on
  23. def pytest_sessionstart(session: pytest.Session) -> None:
  24. tests.common.device = cast(Device, session.config.getoption("device"))