From 47909637a84122592f2af550a624cae744e0632d Mon Sep 17 00:00:00 2001
From: "thomas.kopp"
Date: Thu, 2 Apr 2026 00:55:53 +0200
Subject: [PATCH] feat: transcribe_file returns timestamped segments when with_segments=True

---
Review notes (ignored by git am):
- Line structure restored; indentation of context/removed lines was
  reconstructed by convention and should be checked against the tree.
- _transcribe_local now materialises the segments inside the executor so a
  lazily evaluated result is not decoded on the event loop thread.
- _transcribe_remote drops blank segments, matching _transcribe_local.
- Docstrings added to the three touched methods.

 tests/test_transcription.py | 49 ++++++++++++++++++++++
 transcription.py            | 83 ++++++++++++++++++++++++++++++++-----
 2 files changed, 120 insertions(+), 12 deletions(-)

diff --git a/tests/test_transcription.py b/tests/test_transcription.py
index ff3f497..a9e792e 100644
--- a/tests/test_transcription.py
+++ b/tests/test_transcription.py
@@ -47,3 +47,52 @@ async def test_transcribe_uses_remote_when_base_url_set(tmp_path):
             device="auto", base_url="http://beastix:8000",
         )
     assert result == "Hallo Welt"
+
+
+def test_transcribe_file_returns_segments_when_requested(tmp_path):
+    wav = tmp_path / "test.wav"
+    wav.write_bytes(b"\x00" * 100)
+
+    mock_model = MagicMock()
+    mock_seg = MagicMock()
+    mock_seg.text = " Hallo Welt"
+    mock_seg.start = 0.0
+    mock_seg.end = 1.5
+    mock_model.transcribe.return_value = ([mock_seg], MagicMock())
+
+    from transcription import TranscriptionEngine
+    eng = TranscriptionEngine()
+    eng._model = mock_model
+
+    result = asyncio.run(eng.transcribe_file(str(wav), language="de", with_segments=True))
+    assert isinstance(result, list)
+    assert result[0]["text"] == "Hallo Welt"
+    assert result[0]["start"] == 0.0
+    assert result[0]["end"] == 1.5
+
+
+@pytest.mark.asyncio
+async def test_transcribe_remote_returns_segments_when_requested(tmp_path):
+    import wave, struct
+    wav = tmp_path / "test.wav"
+    with wave.open(str(wav), "wb") as wf:
+        wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(16000)
+        wf.writeframes(struct.pack("<100h", *([0] * 100)))
+
+    import respx, httpx
+    from transcription import TranscriptionEngine
+    eng = TranscriptionEngine()
+
+    with respx.mock:
+        respx.post("http://beastix:8000/v1/audio/transcriptions").mock(
+            return_value=httpx.Response(200, json={
+                "text": "Hallo Welt",
+                "segments": [{"start": 0.0, "end": 1.5, "text": " Hallo Welt"}],
+            })
+        )
+        result = await eng.transcribe_file(
+            str(wav), language="de", model_name="large-v3",
+            device="auto", base_url="http://beastix:8000", with_segments=True,
+        )
+    assert isinstance(result, list)
+    assert result[0]["text"] == "Hallo Welt"
diff --git a/transcription.py b/transcription.py
index b977c50..e0e1630 100644
--- a/transcription.py
+++ b/transcription.py
@@ -1,5 +1,6 @@
 import asyncio
 import httpx
+from typing import Union
 
 
 class TranscriptionEngine:
@@ -25,34 +26,92 @@ class TranscriptionEngine:
         model_name: str = "large-v3",
         device: str = "auto",
         base_url: str = "",
-    ) -> str:
+        with_segments: bool = False,
+    ) -> Union[str, list[dict]]:
+        """Transcribe an audio file and return its text or timed segments.
+
+        A non-empty ``base_url`` selects the remote HTTP backend; otherwise
+        the local model is used. With ``with_segments`` the result is a list
+        of ``{"start", "end", "text"}`` dicts instead of a plain string.
+        """
         if base_url:
-            return await self._transcribe_remote(audio_path, language, model_name, base_url)
-        return await self._transcribe_local(audio_path, language, model_name, device)
+            return await self._transcribe_remote(
+                audio_path, language, model_name, base_url, with_segments
+            )
+        return await self._transcribe_local(
+            audio_path, language, model_name, device, with_segments
+        )
 
     async def _transcribe_remote(
-        self, audio_path: str, language: str, model_name: str, base_url: str
-    ) -> str:
+        self,
+        audio_path: str,
+        language: str,
+        model_name: str,
+        base_url: str,
+        with_segments: bool,
+    ) -> Union[str, list[dict]]:
+        """POST the audio file to ``{base_url}/v1/audio/transcriptions``.
+
+        With ``with_segments`` the request asks for ``verbose_json`` with
+        segment timestamps and returns ``{"start", "end", "text"}`` dicts;
+        otherwise the response's ``text`` field is returned as-is.
+        """
         async with httpx.AsyncClient(timeout=300) as client:
             with open(audio_path, "rb") as f:
+                data = {"model": model_name, "language": language}
+                if with_segments:
+                    data["timestamp_granularities[]"] = "segment"
+                    data["response_format"] = "verbose_json"
                 r = await client.post(
                     f"{base_url}/v1/audio/transcriptions",
                     files={"file": ("audio.wav", f, "audio/wav")},
-                    data={"model": model_name, "language": language},
+                    data=data,
                 )
             r.raise_for_status()
-            return r.json()["text"]
+            body = r.json()
+            if not with_segments:
+                return body["text"]
+            raw_segs = body.get("segments") or []
+            if raw_segs:
+                # Drop blank segments so both backends return the same shape.
+                return [
+                    {"start": s["start"], "end": s["end"], "text": s["text"].strip()}
+                    for s in raw_segs
+                    if s["text"].strip()
+                ]
+            # No segment data in the reply: wrap the whole text in one segment.
+            # 9999.0 looks like a placeholder end; no duration is read here.
+            return [{"start": 0.0, "end": 9999.0, "text": body["text"].strip()}]
 
     async def _transcribe_local(
-        self, audio_path: str, language: str, model_name: str, device: str
-    ) -> str:
+        self,
+        audio_path: str,
+        language: str,
+        model_name: str,
+        device: str,
+        with_segments: bool,
+    ) -> Union[str, list[dict]]:
+        """Transcribe with the local model via the default executor.
+
+        Returns the joined, stripped text, or with ``with_segments`` a list
+        of ``{"start", "end", "text"}`` dicts with blank segments dropped.
+        """
         loop = asyncio.get_running_loop()
         model = self._get_model(model_name, device)
-        segments, _ = await loop.run_in_executor(
+        # Materialise the segments inside the executor: if the backend returns
+        # a lazy generator (as faster-whisper does), iterating it out here
+        # would run the decoding on the event loop thread.
+        segments = await loop.run_in_executor(
             None,
-            lambda: model.transcribe(audio_path, language=language),
+            lambda: list(model.transcribe(audio_path, language=language)[0]),
         )
-        return "".join(seg.text for seg in segments).strip()
+        if not with_segments:
+            return "".join(seg.text for seg in segments).strip()
+        return [
+            {"start": seg.start, "end": seg.end, "text": seg.text.strip()}
+            for seg in segments
+            if seg.text.strip()
+        ]
 
 
 engine = TranscriptionEngine()