# Unit tests for the Diarizer class in the diarization module.
from unittest.mock import MagicMock, patch
|
|
import pytest
|
|
|
|
|
|
def test_diarizer_returns_list_of_tuples(tmp_path):
    """Check that Diarizer.diarize() yields [(start, end, speaker), ...].

    A stand-in pipeline is injected on an instance created with
    ``Diarizer.__new__`` (so ``__init__`` is not run); the pipeline returns
    an annotation whose ``itertracks()`` produces two speaker turns.
    """
    import asyncio

    from diarization import Diarizer

    # Placeholder audio file: 100 zero bytes written under pytest's tmp_path.
    audio_file = tmp_path / "test.wav"
    audio_file.write_bytes(b"\x00" * 100)

    # Each turn only needs the .start / .end attributes checked below.
    first_turn = MagicMock(start=0.0, end=2.5)
    second_turn = MagicMock(start=2.6, end=5.0)

    # itertracks() hands back 3-tuples; the expected result below uses the
    # turn's start/end and the third element as the speaker label.
    annotation = MagicMock()
    annotation.itertracks.return_value = [
        (first_turn, "A", "SPEAKER_00"),
        (second_turn, "B", "SPEAKER_01"),
    ]

    # Calling the fake pipeline returns the canned annotation.
    diarizer = Diarizer.__new__(Diarizer)
    diarizer._pipeline = MagicMock(return_value=annotation)

    segments = asyncio.run(diarizer.diarize(str(audio_file)))
    assert segments == [(0.0, 2.5, "SPEAKER_00"), (2.6, 5.0, "SPEAKER_01")]
|
|
|
|
|
|
def test_diarizer_requires_hf_token():
    """Constructing a Diarizer with an empty hf_token must raise ValueError.

    The error message is expected to mention ``hf_token``.
    """
    from diarization import Diarizer

    empty_token = ""
    expect_value_error = pytest.raises(ValueError, match="hf_token")
    with expect_value_error:
        Diarizer(hf_token=empty_token)
|