test_conformer_shaw.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import torch
  7. from fairseq2.data.audio import AudioDecoderOutput
  8. from fairseq2.nn.padding import get_seqs_and_padding_mask
  9. from seamless_communication.models.conformer_shaw import load_conformer_shaw_model
  10. from tests.common import (
  11. convert_to_collated_fbank,
  12. get_default_dtype,
  13. device,
  14. )
  15. REF_MEAN, REF_STD = -0.0001, 0.1547
  16. def test_conformer_shaw_600m(example_rate16k_audio: AudioDecoderOutput) -> None:
  17. dtype = get_default_dtype()
  18. audio_dict = example_rate16k_audio
  19. src = convert_to_collated_fbank(audio_dict, dtype=dtype)
  20. seqs, padding_mask = get_seqs_and_padding_mask(src)
  21. model = load_conformer_shaw_model("conformer_shaw", device=device, dtype=dtype)
  22. model.eval()
  23. with torch.inference_mode():
  24. seqs, padding_mask = model.encoder_frontend(seqs, padding_mask)
  25. seqs, _ = model.encoder(seqs, padding_mask)
  26. std, mean = torch.std_mean(seqs)
  27. assert round(mean.item(), 4) == REF_MEAN
  28. assert round(std.item(), 4) == REF_STD