feat: Diarizer class wrapping pyannote/speaker-diarization-3.1

This commit is contained in:
2026-04-02 00:59:50 +02:00
parent 47909637a8
commit 1a9d0eacc2
3 changed files with 66 additions and 0 deletions
+27
View File
@@ -0,0 +1,27 @@
import asyncio
class Diarizer:
    """Speaker diarization backed by ``pyannote/speaker-diarization-3.1``.

    The pyannote pipeline is created on the first call that needs it and
    then reused, so constructing a ``Diarizer`` does not import pyannote
    or load the model.
    """

    _MODEL_ID = "pyannote/speaker-diarization-3.1"

    def __init__(self, hf_token: str):
        """Store the Hugging Face token used to load the pipeline.

        Raises:
            ValueError: If ``hf_token`` is empty or otherwise falsy.
        """
        if not hf_token:
            raise ValueError("hf_token is required for diarization")
        self._hf_token = hf_token
        self._pipeline = None

    def _load_pipeline(self):
        """Return the cached pipeline, loading it on first use.

        Raises:
            RuntimeError: If pyannote hands back no pipeline.
        """
        if self._pipeline is None:
            # Imported here so the dependency is only needed once
            # diarization is actually requested.
            from pyannote.audio import Pipeline

            # NOTE(review): newer pyannote.audio releases may rename this
            # keyword to `token` -- confirm against the installed version.
            pipeline = Pipeline.from_pretrained(
                self._MODEL_ID,
                use_auth_token=self._hf_token,
            )
            # from_pretrained may return None instead of raising (for
            # example when the gated model could not be downloaded). Fail
            # here with a clear message rather than later with
            # "'NoneType' object is not callable".
            if pipeline is None:
                raise RuntimeError(
                    f"could not load diarization pipeline {self._MODEL_ID!r}; "
                    "check that the Hugging Face token is valid and that the "
                    "model's user conditions have been accepted"
                )
            self._pipeline = pipeline
        return self._pipeline

    async def diarize(self, wav_path: str) -> list[tuple[float, float, str]]:
        """Run diarization on ``wav_path`` without blocking the event loop.

        Both the pipeline load and the pipeline call are executed in the
        running loop's default executor.

        Returns:
            One ``(start, end, speaker)`` tuple per track yielded by the
            pipeline's annotation, in the order the annotation yields them.
        """
        loop = asyncio.get_running_loop()
        pipeline = await loop.run_in_executor(None, self._load_pipeline)
        annotation = await loop.run_in_executor(None, pipeline, wav_path)
        return [
            (turn.start, turn.end, speaker)
            for turn, _, speaker in annotation.itertracks(yield_label=True)
        ]
+1
View File
@@ -9,3 +9,4 @@ numpy>=1.26
tomli_w>=1.0
pytest>=8.0
pytest-asyncio>=0.23
pyannote.audio>=3.3
+38
View File
@@ -0,0 +1,38 @@
from unittest.mock import MagicMock, patch
import pytest
def test_diarizer_returns_list_of_tuples(tmp_path):
    """diarize() turns the pipeline's tracks into (start, end, speaker) tuples."""
    import asyncio

    from diarization import Diarizer

    # Only a path is required: the stubbed pipeline never reads the file.
    audio_file = tmp_path / "test.wav"
    audio_file.write_bytes(bytes(100))

    first_turn = MagicMock(start=0.0, end=2.5)
    second_turn = MagicMock(start=2.6, end=5.0)

    fake_annotation = MagicMock()
    fake_annotation.itertracks.return_value = [
        (first_turn, "A", "SPEAKER_00"),
        (second_turn, "B", "SPEAKER_01"),
    ]

    # Bypass __init__ and pre-seed the pipeline so nothing is loaded.
    diarizer = Diarizer.__new__(Diarizer)
    diarizer._pipeline = MagicMock(return_value=fake_annotation)

    segments = asyncio.run(diarizer.diarize(str(audio_file)))

    expected = [(0.0, 2.5, "SPEAKER_00"), (2.6, 5.0, "SPEAKER_01")]
    assert segments == expected
def test_diarizer_requires_hf_token():
    """Constructing a Diarizer with an empty token raises ValueError."""
    import diarization

    empty_token = ""
    with pytest.raises(ValueError, match="hf_token"):
        diarization.Diarizer(hf_token=empty_token)