From 1a9d0eacc2351368f0743a0328b47467d33346b3 Mon Sep 17 00:00:00 2001
From: "thomas.kopp"
Date: Thu, 2 Apr 2026 00:59:50 +0200
Subject: [PATCH] feat: Diarizer class wrapping pyannote/speaker-diarization-3.1

---
Revised after export: diarization.py now raises RuntimeError when the
pipeline cannot be loaded. The commit id above and the diarization.py blob
id below predate this revision.

 diarization.py            | 39 +++++++++++++++++++++++++++++++++++++++
 requirements.txt          |  1 +
 tests/test_diarization.py | 38 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 78 insertions(+)
 create mode 100644 diarization.py
 create mode 100644 tests/test_diarization.py

diff --git a/diarization.py b/diarization.py
new file mode 100644
index 0000000..bb1527d
--- /dev/null
+++ b/diarization.py
@@ -0,0 +1,39 @@
+import asyncio
+
+
+class Diarizer:
+    """Speaker diarization backed by pyannote/speaker-diarization-3.1."""
+
+    def __init__(self, hf_token: str):
+        if not hf_token:
+            raise ValueError("hf_token is required for diarization")
+        self._hf_token = hf_token
+        self._pipeline = None  # loaded lazily by _load_pipeline()
+
+    def _load_pipeline(self):
+        """Return the cached pipeline, loading it on first use."""
+        if self._pipeline is None:
+            from pyannote.audio import Pipeline
+            pipeline = Pipeline.from_pretrained(
+                "pyannote/speaker-diarization-3.1",
+                use_auth_token=self._hf_token,
+            )
+            # from_pretrained() may return None on failure rather than raise.
+            if pipeline is None:
+                raise RuntimeError(
+                    "could not load pyannote/speaker-diarization-3.1; "
+                    "check hf_token and the model's access conditions"
+                )
+            self._pipeline = pipeline
+        return self._pipeline
+
+    async def diarize(self, wav_path: str) -> list[tuple[float, float, str]]:
+        """Return a (start, end, speaker) tuple per speaker turn."""
+        loop = asyncio.get_running_loop()
+        # Run loading and inference in the default executor, off the loop.
+        pipeline = await loop.run_in_executor(None, self._load_pipeline)
+        annotation = await loop.run_in_executor(None, pipeline, wav_path)
+        return [
+            (turn.start, turn.end, speaker)
+            for turn, _, speaker in annotation.itertracks(yield_label=True)
+        ]
diff --git a/requirements.txt b/requirements.txt
index 6568721..96ee457 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@ numpy>=1.26
 tomli_w>=1.0
 pytest>=8.0
 pytest-asyncio>=0.23
+pyannote.audio>=3.3
diff --git a/tests/test_diarization.py b/tests/test_diarization.py
new file mode 100644
index 0000000..c3b2908
--- /dev/null
+++ b/tests/test_diarization.py
@@ -0,0 +1,38 @@
+from unittest.mock import MagicMock, patch
+import pytest
+
+
+def test_diarizer_returns_list_of_tuples(tmp_path):
+    """Diarizer.diarize() returns [(start, end, speaker), ...]"""
+    wav = tmp_path / "test.wav"
+    wav.write_bytes(b"\x00" * 100)
+
+    mock_turn_1 = MagicMock()
+    mock_turn_1.start = 0.0
+    mock_turn_1.end = 2.5
+
+    mock_turn_2 = MagicMock()
+    mock_turn_2.start = 2.6
+    mock_turn_2.end = 5.0
+
+    mock_annotation = MagicMock()
+    mock_annotation.itertracks.return_value = [
+        (mock_turn_1, "A", "SPEAKER_00"),
+        (mock_turn_2, "B", "SPEAKER_01"),
+    ]
+
+    mock_pipeline = MagicMock(return_value=mock_annotation)
+
+    import asyncio
+    from diarization import Diarizer
+    d = Diarizer.__new__(Diarizer)
+    d._pipeline = mock_pipeline
+
+    result = asyncio.run(d.diarize(str(wav)))
+    assert result == [(0.0, 2.5, "SPEAKER_00"), (2.6, 5.0, "SPEAKER_01")]
+
+
+def test_diarizer_requires_hf_token():
+    from diarization import Diarizer
+    with pytest.raises(ValueError, match="hf_token"):
+        Diarizer(hf_token="")