test_silero_vad.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # MIT_LICENSE file in the root directory of this source tree.
  6. import unittest
  7. from argparse import Namespace
  8. from unittest.mock import Mock
  9. from seamless_communication.segment.silero_vad import SileroVADSegmenter, Segment
  10. import numpy as np
  11. class TestSileroVADSegmenter(unittest.TestCase):
  12. def test_init_works(self):
  13. segmenter = SileroVADSegmenter(
  14. sample_rate=16000,
  15. chunk_size_sec=10,
  16. pause_length=0.5)
  17. self.assertEqual(segmenter.sample_rate, 16000)
  18. self.assertEqual(segmenter.chunk_size_sec, 10)
  19. self.assertEqual(segmenter.pause_length, 0.5)
  20. def test_segment_long_input(self):
  21. self.segmenter = SileroVADSegmenter(
  22. sample_rate=16000,
  23. chunk_size_sec=10,
  24. pause_length=0.5)
  25. self.segmenter.get_speech_timestamps = Mock(
  26. return_value=[{0: 0, 1: 10000},
  27. {0: 20000, 1: 30000}])
  28. segments = self.segmenter.segment_long_input(audio=None)
  29. expected_segments = [[0, 10000], [20000, 30000]]
  30. self.assertEqual(segments, expected_segments)
  31. def test_recursive_split(self):
  32. segmenter = SileroVADSegmenter(
  33. sample_rate=16000,
  34. chunk_size_sec=10,
  35. pause_length=0.5)
  36. sgm = Segment(0, 10000, np.random.rand(10000))
  37. segments = []
  38. max_segment_length = 5000
  39. min_segment_length = 1000
  40. window_size_samples = 100
  41. threshold = .5
  42. segmenter.recursive_split(
  43. sgm,
  44. segments,
  45. max_segment_length,
  46. min_segment_length,
  47. window_size_samples,
  48. threshold)
  49. assert all([seg.duration < max_segment_length for seg in segments])
  50. assert all([seg.duration > min_segment_length for seg in segments])