feat: transcribe_file returns timestamped segments when with_segments=True
This commit is contained in:
@@ -47,3 +47,52 @@ async def test_transcribe_uses_remote_when_base_url_set(tmp_path):
|
||||
device="auto", base_url="http://beastix:8000",
|
||||
)
|
||||
assert result == "Hallo Welt"
|
||||
|
||||
|
||||
def test_transcribe_file_returns_segments_when_requested(tmp_path):
|
||||
wav = tmp_path / "test.wav"
|
||||
wav.write_bytes(b"\x00" * 100)
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_seg = MagicMock()
|
||||
mock_seg.text = " Hallo Welt"
|
||||
mock_seg.start = 0.0
|
||||
mock_seg.end = 1.5
|
||||
mock_model.transcribe.return_value = ([mock_seg], MagicMock())
|
||||
|
||||
from transcription import TranscriptionEngine
|
||||
eng = TranscriptionEngine()
|
||||
eng._model = mock_model
|
||||
|
||||
result = asyncio.run(eng.transcribe_file(str(wav), language="de", with_segments=True))
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["text"] == "Hallo Welt"
|
||||
assert result[0]["start"] == 0.0
|
||||
assert result[0]["end"] == 1.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcribe_remote_returns_segments_when_requested(tmp_path):
|
||||
import wave, struct
|
||||
wav = tmp_path / "test.wav"
|
||||
with wave.open(str(wav), "wb") as wf:
|
||||
wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(16000)
|
||||
wf.writeframes(struct.pack("<100h", *([0] * 100)))
|
||||
|
||||
import respx, httpx
|
||||
from transcription import TranscriptionEngine
|
||||
eng = TranscriptionEngine()
|
||||
|
||||
with respx.mock:
|
||||
respx.post("http://beastix:8000/v1/audio/transcriptions").mock(
|
||||
return_value=httpx.Response(200, json={
|
||||
"text": "Hallo Welt",
|
||||
"segments": [{"start": 0.0, "end": 1.5, "text": " Hallo Welt"}],
|
||||
})
|
||||
)
|
||||
result = await eng.transcribe_file(
|
||||
str(wav), language="de", model_name="large-v3",
|
||||
device="auto", base_url="http://beastix:8000", with_segments=True,
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["text"] == "Hallo Welt"
|
||||
|
||||
+45
-9
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import httpx
|
||||
from typing import Union
|
||||
|
||||
|
||||
class TranscriptionEngine:
|
||||
@@ -25,34 +26,69 @@ class TranscriptionEngine:
|
||||
model_name: str = "large-v3",
|
||||
device: str = "auto",
|
||||
base_url: str = "",
|
||||
) -> str:
|
||||
with_segments: bool = False,
|
||||
) -> Union[str, list[dict]]:
|
||||
if base_url:
|
||||
return await self._transcribe_remote(audio_path, language, model_name, base_url)
|
||||
return await self._transcribe_local(audio_path, language, model_name, device)
|
||||
return await self._transcribe_remote(
|
||||
audio_path, language, model_name, base_url, with_segments
|
||||
)
|
||||
return await self._transcribe_local(
|
||||
audio_path, language, model_name, device, with_segments
|
||||
)
|
||||
|
||||
async def _transcribe_remote(
|
||||
self, audio_path: str, language: str, model_name: str, base_url: str
|
||||
) -> str:
|
||||
self,
|
||||
audio_path: str,
|
||||
language: str,
|
||||
model_name: str,
|
||||
base_url: str,
|
||||
with_segments: bool,
|
||||
) -> Union[str, list[dict]]:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
with open(audio_path, "rb") as f:
|
||||
data = {"model": model_name, "language": language}
|
||||
if with_segments:
|
||||
data["timestamp_granularities[]"] = "segment"
|
||||
data["response_format"] = "verbose_json"
|
||||
r = await client.post(
|
||||
f"{base_url}/v1/audio/transcriptions",
|
||||
files={"file": ("audio.wav", f, "audio/wav")},
|
||||
data={"model": model_name, "language": language},
|
||||
data=data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()["text"]
|
||||
body = r.json()
|
||||
if not with_segments:
|
||||
return body["text"]
|
||||
raw_segs = body.get("segments") or []
|
||||
if raw_segs:
|
||||
return [
|
||||
{"start": s["start"], "end": s["end"], "text": s["text"].strip()}
|
||||
for s in raw_segs
|
||||
]
|
||||
return [{"start": 0.0, "end": 9999.0, "text": body["text"].strip()}]
|
||||
|
||||
async def _transcribe_local(
|
||||
self, audio_path: str, language: str, model_name: str, device: str
|
||||
) -> str:
|
||||
self,
|
||||
audio_path: str,
|
||||
language: str,
|
||||
model_name: str,
|
||||
device: str,
|
||||
with_segments: bool,
|
||||
) -> Union[str, list[dict]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
model = self._get_model(model_name, device)
|
||||
segments, _ = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: model.transcribe(audio_path, language=language),
|
||||
)
|
||||
segments = list(segments)
|
||||
if not with_segments:
|
||||
return "".join(seg.text for seg in segments).strip()
|
||||
return [
|
||||
{"start": seg.start, "end": seg.end, "text": seg.text.strip()}
|
||||
for seg in segments
|
||||
if seg.text.strip()
|
||||
]
|
||||
|
||||
|
||||
engine = TranscriptionEngine()
|
||||
|
||||
Reference in New Issue
Block a user