feat: Diarizer class wrapping pyannote/speaker-diarization-3.1
This commit is contained in:
@@ -0,0 +1,27 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class Diarizer:
|
||||||
|
def __init__(self, hf_token: str):
|
||||||
|
if not hf_token:
|
||||||
|
raise ValueError("hf_token is required for diarization")
|
||||||
|
self._hf_token = hf_token
|
||||||
|
self._pipeline = None
|
||||||
|
|
||||||
|
def _load_pipeline(self):
|
||||||
|
if self._pipeline is None:
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
self._pipeline = Pipeline.from_pretrained(
|
||||||
|
"pyannote/speaker-diarization-3.1",
|
||||||
|
use_auth_token=self._hf_token,
|
||||||
|
)
|
||||||
|
return self._pipeline
|
||||||
|
|
||||||
|
async def diarize(self, wav_path: str) -> list[tuple[float, float, str]]:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
pipeline = await loop.run_in_executor(None, self._load_pipeline)
|
||||||
|
annotation = await loop.run_in_executor(None, lambda: pipeline(wav_path))
|
||||||
|
return [
|
||||||
|
(turn.start, turn.end, speaker)
|
||||||
|
for turn, _, speaker in annotation.itertracks(yield_label=True)
|
||||||
|
]
|
||||||
@@ -9,3 +9,4 @@ numpy>=1.26
|
|||||||
tomli_w>=1.0
|
tomli_w>=1.0
|
||||||
pytest>=8.0
|
pytest>=8.0
|
||||||
pytest-asyncio>=0.23
|
pytest-asyncio>=0.23
|
||||||
|
pyannote.audio>=3.3
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarizer_returns_list_of_tuples(tmp_path):
|
||||||
|
"""Diarizer.diarize() returns [(start, end, speaker), ...]"""
|
||||||
|
wav = tmp_path / "test.wav"
|
||||||
|
wav.write_bytes(b"\x00" * 100)
|
||||||
|
|
||||||
|
mock_turn_1 = MagicMock()
|
||||||
|
mock_turn_1.start = 0.0
|
||||||
|
mock_turn_1.end = 2.5
|
||||||
|
|
||||||
|
mock_turn_2 = MagicMock()
|
||||||
|
mock_turn_2.start = 2.6
|
||||||
|
mock_turn_2.end = 5.0
|
||||||
|
|
||||||
|
mock_annotation = MagicMock()
|
||||||
|
mock_annotation.itertracks.return_value = [
|
||||||
|
(mock_turn_1, "A", "SPEAKER_00"),
|
||||||
|
(mock_turn_2, "B", "SPEAKER_01"),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_pipeline = MagicMock(return_value=mock_annotation)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from diarization import Diarizer
|
||||||
|
d = Diarizer.__new__(Diarizer)
|
||||||
|
d._pipeline = mock_pipeline
|
||||||
|
|
||||||
|
result = asyncio.run(d.diarize(str(wav)))
|
||||||
|
assert result == [(0.0, 2.5, "SPEAKER_00"), (2.6, 5.0, "SPEAKER_01")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarizer_requires_hf_token():
|
||||||
|
from diarization import Diarizer
|
||||||
|
with pytest.raises(ValueError, match="hf_token"):
|
||||||
|
Diarizer(hf_token="")
|
||||||
Reference in New Issue
Block a user