Bladeren bron

Loosen the PretsselModel test check by allowing a one-unit difference between runs (#127)

Yilin Yang 1 jaar geleden
bovenliggende
commit
b9f101b2b7
2 gewijzigde bestanden met toevoegingen van 27 en 4 verwijderingen
  1. 23 0
      tests/common.py
  2. 4 4
      tests/integration/models/test_expressivity.py

+ 23 - 0
tests/common.py

@@ -39,6 +39,29 @@ def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None:
     torch.testing.assert_close(a, b, rtol=0, atol=0)  # type: ignore[attr-defined]
 
 
+def assert_unit_close(
+    a: Tensor,
+    b: Union[Tensor, List[Any]],
+    num_unit_tol: int = 1,
+    percent_unit_tol: float = 0.0,
+) -> None:
+    """Assert that two unit sequences differ in at most a tolerated number of positions.
+
+    :param a: The units under test.
+    :param b: The reference units; a list is converted to a tensor with the
+        device and dtype of ``a``.
+    :param num_unit_tol: The maximum number of mismatching elements allowed.
+    :param percent_unit_tol: If positive, overrides ``num_unit_tol`` with this
+        fraction (e.g. ``0.05``) of the element count of ``a``, rounded down.
+
+    :raises AssertionError: If the shapes differ or too many elements mismatch.
+    """
+    if not isinstance(b, Tensor):
+        # Build the reference next to ``a`` so the comparison never mixes devices.
+        b = torch.tensor(b, device=a.device, dtype=a.dtype)
+
+    assert (
+        a.shape == b.shape
+    ), f"Two shapes are different, one is {a.shape}, the other is {b.shape}"
+
+    if percent_unit_tol > 0.0:
+        # ``numel`` rather than ``len``: ``len`` only sees the first dimension.
+        num_unit_tol = int(percent_unit_tol * a.numel())
+
+    # Convert to a plain ``int`` so the message reads "3 units", not "tensor(3) units".
+    num_unit_diff = int((a != b).sum())
+    assert (
+        num_unit_diff <= num_unit_tol
+    ), f"The difference is beyond tolerance, {num_unit_diff} units are different, tolerance is {num_unit_tol}"
+
+
 def has_no_inf(a: Tensor) -> bool:
     """Return ``True`` when ``a`` contains neither ``+inf`` nor ``-inf``."""
     return not torch.isinf(a).any()

+ 4 - 4
tests/integration/models/test_expressivity.py

@@ -15,7 +15,7 @@ from seamless_communication.inference.pretssel_generator import PretsselGenerato
 from seamless_communication.models.unit_extractor import UnitExtractor
 from seamless_communication.models.unity import load_gcmvn_stats
 from tests.common import (
-    assert_equal,
+    assert_unit_close,
     convert_to_collated_fbank,
     device,
     get_default_dtype,
@@ -28,7 +28,7 @@ REF_WAVE_EXTRACTED_UNITS: Final = [8976, 2066, 3800, 2357, 2357, 8080, 9479, 218
 
 
 def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> None:
-    # float16 is seeing non-deterministic behavior
+    # this model is seeing non-deterministic behavior (fp32 is better)
     dtype = torch.float32
 
     audio_dict = example_rate16k_audio
@@ -54,7 +54,7 @@ def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> Non
     units = tensor(speech_output.units[0], device=device, dtype=torch.int64)
 
     # same target units
-    assert_equal(units, tensor(REF_UNITS).to(units))
+    assert_unit_close(units, REF_UNITS)
 
     pretssel_generator = PretsselGenerator(
         unity_model_name,
@@ -81,4 +81,4 @@ def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> Non
     )
     units = unit_extractor.predict(waveform, 34)
 
-    assert_equal(units, tensor(REF_WAVE_EXTRACTED_UNITS).to(units))
+    assert_unit_close(units, REF_WAVE_EXTRACTED_UNITS)