1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- # Copyright (c) Meta Platforms, Inc. and affiliates
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # MIT_LICENSE file in the root directory of this source tree.
- import unittest
- from argparse import Namespace
- from unittest.mock import Mock
- from seamless_communication.segment.silero_vad import SileroVADSegmenter, Segment
- import numpy as np
- class TestSileroVADSegmenter(unittest.TestCase):
- def test_init_works(self):
- segmenter = SileroVADSegmenter(
- sample_rate=16000,
- chunk_size_sec=10,
- pause_length=0.5)
- self.assertEqual(segmenter.sample_rate, 16000)
- self.assertEqual(segmenter.chunk_size_sec, 10)
- self.assertEqual(segmenter.pause_length, 0.5)
- def test_segment_long_input(self):
- self.segmenter = SileroVADSegmenter(
- sample_rate=16000,
- chunk_size_sec=10,
- pause_length=0.5)
- self.segmenter.get_speech_timestamps = Mock(
- return_value=[{0: 0, 1: 10000},
- {0: 20000, 1: 30000}])
- segments = self.segmenter.segment_long_input(audio=None)
- expected_segments = [[0, 10000], [20000, 30000]]
- self.assertEqual(segments, expected_segments)
- def test_recursive_split(self):
- segmenter = SileroVADSegmenter(
- sample_rate=16000,
- chunk_size_sec=10,
- pause_length=0.5)
- sgm = Segment(0, 10000, np.random.rand(10000))
- segments = []
- max_segment_length = 5000
- min_segment_length = 1000
- window_size_samples = 100
- threshold = .5
- segmenter.recursive_split(
- sgm,
- segments,
- max_segment_length,
- min_segment_length,
- window_size_samples,
- threshold)
- assert all([seg.duration < max_segment_length for seg in segments])
- assert all([seg.duration > min_segment_length for seg in segments])
|