feat: transcribe_file returns timestamped segments when with_segments=True

This commit is contained in:
2026-04-02 00:55:53 +02:00
parent 7dfc0e0c5f
commit 47909637a8
2 changed files with 95 additions and 10 deletions
+49
View File
@@ -47,3 +47,52 @@ async def test_transcribe_uses_remote_when_base_url_set(tmp_path):
device="auto", base_url="http://beastix:8000",
)
assert result == "Hallo Welt"
def test_transcribe_file_returns_segments_when_requested(tmp_path):
wav = tmp_path / "test.wav"
wav.write_bytes(b"\x00" * 100)
mock_model = MagicMock()
mock_seg = MagicMock()
mock_seg.text = " Hallo Welt"
mock_seg.start = 0.0
mock_seg.end = 1.5
mock_model.transcribe.return_value = ([mock_seg], MagicMock())
from transcription import TranscriptionEngine
eng = TranscriptionEngine()
eng._model = mock_model
result = asyncio.run(eng.transcribe_file(str(wav), language="de", with_segments=True))
assert isinstance(result, list)
assert result[0]["text"] == "Hallo Welt"
assert result[0]["start"] == 0.0
assert result[0]["end"] == 1.5
@pytest.mark.asyncio
async def test_transcribe_remote_returns_segments_when_requested(tmp_path):
import wave, struct
wav = tmp_path / "test.wav"
with wave.open(str(wav), "wb") as wf:
wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(16000)
wf.writeframes(struct.pack("<100h", *([0] * 100)))
import respx, httpx
from transcription import TranscriptionEngine
eng = TranscriptionEngine()
with respx.mock:
respx.post("http://beastix:8000/v1/audio/transcriptions").mock(
return_value=httpx.Response(200, json={
"text": "Hallo Welt",
"segments": [{"start": 0.0, "end": 1.5, "text": " Hallo Welt"}],
})
)
result = await eng.transcribe_file(
str(wav), language="de", model_name="large-v3",
device="auto", base_url="http://beastix:8000", with_segments=True,
)
assert isinstance(result, list)
assert result[0]["text"] == "Hallo Welt"
+46 -10
View File
@@ -1,5 +1,6 @@
import asyncio
import httpx
from typing import Union
class TranscriptionEngine:
@@ -25,34 +26,69 @@ class TranscriptionEngine:
model_name: str = "large-v3",
device: str = "auto",
base_url: str = "",
) -> str:
with_segments: bool = False,
) -> Union[str, list[dict]]:
if base_url:
return await self._transcribe_remote(audio_path, language, model_name, base_url)
return await self._transcribe_local(audio_path, language, model_name, device)
return await self._transcribe_remote(
audio_path, language, model_name, base_url, with_segments
)
return await self._transcribe_local(
audio_path, language, model_name, device, with_segments
)
async def _transcribe_remote(
self, audio_path: str, language: str, model_name: str, base_url: str
) -> str:
self,
audio_path: str,
language: str,
model_name: str,
base_url: str,
with_segments: bool,
) -> Union[str, list[dict]]:
async with httpx.AsyncClient(timeout=300) as client:
with open(audio_path, "rb") as f:
data = {"model": model_name, "language": language}
if with_segments:
data["timestamp_granularities[]"] = "segment"
data["response_format"] = "verbose_json"
r = await client.post(
f"{base_url}/v1/audio/transcriptions",
files={"file": ("audio.wav", f, "audio/wav")},
data={"model": model_name, "language": language},
data=data,
)
r.raise_for_status()
return r.json()["text"]
body = r.json()
if not with_segments:
return body["text"]
raw_segs = body.get("segments") or []
if raw_segs:
return [
{"start": s["start"], "end": s["end"], "text": s["text"].strip()}
for s in raw_segs
]
return [{"start": 0.0, "end": 9999.0, "text": body["text"].strip()}]
async def _transcribe_local(
self, audio_path: str, language: str, model_name: str, device: str
) -> str:
self,
audio_path: str,
language: str,
model_name: str,
device: str,
with_segments: bool,
) -> Union[str, list[dict]]:
loop = asyncio.get_running_loop()
model = self._get_model(model_name, device)
segments, _ = await loop.run_in_executor(
None,
lambda: model.transcribe(audio_path, language=language),
)
return "".join(seg.text for seg in segments).strip()
segments = list(segments)
if not with_segments:
return "".join(seg.text for seg in segments).strip()
return [
{"start": seg.start, "end": seg.end, "text": seg.text.strip()}
for seg in segments
if seg.text.strip()
]
engine = TranscriptionEngine()