| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 | 
							- # Copyright (c) Meta Platforms, Inc. and affiliates
 
- # All rights reserved.
 
- #
 
- # This source code is licensed under the license found in the
 
- # LICENSE file in the root directory of this source tree.
 
- import tempfile
 
- from argparse import ArgumentTypeError
 
- from typing import cast
 
- from urllib.request import urlretrieve
 
- import pytest
 
- import torch
 
- from fairseq2.data.audio import AudioDecoder, AudioDecoderOutput
 
- from fairseq2.memory import MemoryBlock
 
- from fairseq2.typing import Device
 
- import tests.common
 
def parse_device_arg(value: str) -> Device:
    """Convert a command line string into a :class:`Device`.

    This function is passed as the ``type`` callable of the ``--device``
    option, so failures are reported as :class:`ArgumentTypeError`.

    :param value:
        The device name as typed on the command line.

    :returns:
        The :class:`Device` constructed from ``value``.

    :raises ArgumentTypeError:
        If constructing :class:`Device` from ``value`` raises a
        :class:`RuntimeError`.
    """
    try:
        return Device(value)
    except RuntimeError as ex:
        # Chain explicitly so that the traceback presents the original error
        # as the direct cause instead of as a second, unrelated failure.
        raise ArgumentTypeError(f"'{value}' is not a valid device name.") from ex
 
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the ``--device`` command line option.

    :param parser:
        The pytest option parser to which the option is added.
    """
    device_option = {
        "default": "cpu",
        "type": parse_device_arg,
        "help": "device on which to run tests (default: %(default)s)",
    }

    parser.addoption("--device", **device_option)
 
def pytest_sessionstart(session: pytest.Session) -> None:
    """Store the value of the ``device`` option in ``tests.common.device``.

    :param session:
        The pytest session whose configuration holds the parsed options.
    """
    selected_device = session.config.getoption("device")

    # ``getoption`` is not typed precisely; the cast only informs the type
    # checker and performs no conversion at runtime.
    tests.common.device = cast(Device, selected_device)
 
@pytest.fixture(scope="module")
def example_rate16k_audio() -> AudioDecoderOutput:
    """Download a sample WAV file and return its decoded audio.

    The file is fetched over the network, so tests using this fixture need
    internet access. The audio is decoded as ``torch.float32`` on
    ``tests.common.device``.

    NOTE(review): the file name suggests a 16 kHz sample rate; nothing in this
    fixture verifies it.

    :returns:
        The output of :class:`AudioDecoder` for the downloaded file.
    """
    # Imported locally so that this fixture does not rely on the module-level
    # import list, which only brings in ``urlretrieve``.
    from urllib.request import urlopen

    url = "https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav"

    audio_decoder = AudioDecoder(dtype=torch.float32, device=tests.common.device)

    # Read the payload straight into memory. Going through a named temporary
    # file that is re-opened by name while still open is not portable (it
    # fails on Windows) and adds a needless round-trip to disk.
    with urlopen(url) as response:
        block = MemoryBlock(response.read())

    return audio_decoder(block)
 
 
  |