|
@@ -15,7 +15,7 @@ from seamless_communication.inference.pretssel_generator import PretsselGenerato
|
|
from seamless_communication.models.unit_extractor import UnitExtractor
|
|
from seamless_communication.models.unit_extractor import UnitExtractor
|
|
from seamless_communication.models.unity import load_gcmvn_stats
|
|
from seamless_communication.models.unity import load_gcmvn_stats
|
|
from tests.common import (
|
|
from tests.common import (
|
|
- assert_equal,
|
|
|
|
|
|
+ assert_unit_close,
|
|
convert_to_collated_fbank,
|
|
convert_to_collated_fbank,
|
|
device,
|
|
device,
|
|
get_default_dtype,
|
|
get_default_dtype,
|
|
@@ -28,7 +28,7 @@ REF_WAVE_EXTRACTED_UNITS: Final = [8976, 2066, 3800, 2357, 2357, 8080, 9479, 218
|
|
|
|
|
|
|
|
|
|
def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> None:
|
|
def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> None:
|
|
- # float16 is seeing non-deterministic behavior
|
|
|
|
|
|
+ # this model is seeing non-deterministic behavior (fp32 is better)
|
|
dtype = torch.float32
|
|
dtype = torch.float32
|
|
|
|
|
|
audio_dict = example_rate16k_audio
|
|
audio_dict = example_rate16k_audio
|
|
@@ -54,7 +54,7 @@ def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> Non
|
|
units = tensor(speech_output.units[0], device=device, dtype=torch.int64)
|
|
units = tensor(speech_output.units[0], device=device, dtype=torch.int64)
|
|
|
|
|
|
# same target units
|
|
# same target units
|
|
- assert_equal(units, tensor(REF_UNITS).to(units))
|
|
|
|
|
|
+ assert_unit_close(units, REF_UNITS)
|
|
|
|
|
|
pretssel_generator = PretsselGenerator(
|
|
pretssel_generator = PretsselGenerator(
|
|
unity_model_name,
|
|
unity_model_name,
|
|
@@ -81,4 +81,4 @@ def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> Non
|
|
)
|
|
)
|
|
units = unit_extractor.predict(waveform, 34)
|
|
units = unit_extractor.predict(waveform, 34)
|
|
|
|
|
|
- assert_equal(units, tensor(REF_WAVE_EXTRACTED_UNITS).to(units))
|
|
|
|
|
|
+ assert_unit_close(units, REF_WAVE_EXTRACTED_UNITS)
|