feat: transcribe_file returns timestamped segments when with_segments=True
This commit is contained in:
@@ -47,3 +47,52 @@ async def test_transcribe_uses_remote_when_base_url_set(tmp_path):
|
|||||||
device="auto", base_url="http://beastix:8000",
|
device="auto", base_url="http://beastix:8000",
|
||||||
)
|
)
|
||||||
assert result == "Hallo Welt"
|
assert result == "Hallo Welt"
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_file_returns_segments_when_requested(tmp_path):
    """With ``with_segments=True`` and no ``base_url``, a list of segment dicts comes back.

    The model is replaced by a mock whose ``transcribe`` returns one segment,
    so the dummy file content is never decoded.
    """
    from transcription import TranscriptionEngine

    audio_file = tmp_path / "test.wav"
    audio_file.write_bytes(b"\x00" * 100)

    # One fake segment; the leading space checks that the text is stripped.
    fake_segment = MagicMock()
    fake_segment.text = " Hallo Welt"
    fake_segment.start = 0.0
    fake_segment.end = 1.5

    fake_model = MagicMock()
    fake_model.transcribe.return_value = ([fake_segment], MagicMock())

    stt = TranscriptionEngine()
    stt._model = fake_model

    result = asyncio.run(
        stt.transcribe_file(str(audio_file), language="de", with_segments=True)
    )

    assert isinstance(result, list)
    first = result[0]
    assert first["text"] == "Hallo Welt"
    assert first["start"] == 0.0
    assert first["end"] == 1.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcribe_remote_returns_segments_when_requested(tmp_path):
    """With ``base_url`` set and ``with_segments=True``, remote segments are returned.

    The HTTP endpoint is stubbed with respx and answers with one segment.
    The test checks text stripping *and* that the timestamps survive, so a
    regression that drops ``start``/``end`` on the remote path is caught.
    """
    import struct
    import wave

    import httpx
    import respx

    from transcription import TranscriptionEngine

    # Write a small but valid mono 16-bit / 16 kHz WAV file (100 silent frames).
    wav = tmp_path / "test.wav"
    with wave.open(str(wav), "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(16000)
        wf.writeframes(struct.pack("<100h", *([0] * 100)))

    eng = TranscriptionEngine()

    with respx.mock:
        respx.post("http://beastix:8000/v1/audio/transcriptions").mock(
            return_value=httpx.Response(
                200,
                json={
                    "text": "Hallo Welt",
                    "segments": [{"start": 0.0, "end": 1.5, "text": " Hallo Welt"}],
                },
            )
        )
        result = await eng.transcribe_file(
            str(wav),
            language="de",
            model_name="large-v3",
            device="auto",
            base_url="http://beastix:8000",
            with_segments=True,
        )

    assert isinstance(result, list)
    assert len(result) == 1
    assert result[0]["text"] == "Hallo Welt"
    assert result[0]["start"] == 0.0
    assert result[0]["end"] == 1.5
|
||||||
|
|||||||
+46
-10
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import httpx
|
import httpx
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
@@ -25,34 +26,69 @@ class TranscriptionEngine:
|
|||||||
model_name: str = "large-v3",
|
model_name: str = "large-v3",
|
||||||
device: str = "auto",
|
device: str = "auto",
|
||||||
base_url: str = "",
|
base_url: str = "",
|
||||||
) -> str:
|
with_segments: bool = False,
|
||||||
|
) -> Union[str, list[dict]]:
|
||||||
if base_url:
|
if base_url:
|
||||||
return await self._transcribe_remote(audio_path, language, model_name, base_url)
|
return await self._transcribe_remote(
|
||||||
return await self._transcribe_local(audio_path, language, model_name, device)
|
audio_path, language, model_name, base_url, with_segments
|
||||||
|
)
|
||||||
|
return await self._transcribe_local(
|
||||||
|
audio_path, language, model_name, device, with_segments
|
||||||
|
)
|
||||||
|
|
||||||
async def _transcribe_remote(
    self,
    audio_path: str,
    language: str,
    model_name: str,
    base_url: str,
    with_segments: bool,
) -> Union[str, list[dict]]:
    """Transcribe ``audio_path`` by POSTing it to ``{base_url}/v1/audio/transcriptions``.

    Args:
        audio_path: Path of the audio file to upload (sent as ``audio.wav``).
        language: Language code forwarded in the form field ``language``.
        model_name: Model identifier forwarded in the form field ``model``.
        base_url: Server root; a trailing ``/`` is tolerated.
        with_segments: When true, request ``verbose_json`` with segment
            timestamps and return a list of segment dicts instead of text.

    Returns:
        The response's ``"text"`` value when ``with_segments`` is false.
        Otherwise a list of ``{"start", "end", "text"}`` dicts with stripped
        text; segments whose text is empty after stripping are dropped, the
        same as in ``_transcribe_local``. If the response carries no usable
        segments, the whole transcript is returned as a single segment, or
        an empty list if the transcript is empty too.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        KeyError: If the response lacks ``"text"`` when it is needed.
    """
    form = {"model": model_name, "language": language}
    if with_segments:
        form["timestamp_granularities[]"] = "segment"
        form["response_format"] = "verbose_json"

    # Avoid "//v1/..." when the configured base URL ends with a slash.
    url = f"{base_url.rstrip('/')}/v1/audio/transcriptions"

    async with httpx.AsyncClient(timeout=300) as client:
        with open(audio_path, "rb") as f:
            r = await client.post(
                url,
                files={"file": ("audio.wav", f, "audio/wav")},
                data=form,
            )
        r.raise_for_status()
        body = r.json()

    if not with_segments:
        return body["text"]

    # "segments" may be missing or null; treat both as "no segments".
    segments = [
        {"start": s["start"], "end": s["end"], "text": text}
        for s in body.get("segments") or []
        if (text := s["text"].strip())
    ]
    if segments:
        return segments

    # No per-segment timing available: fall back to one segment spanning the
    # whole transcript. 9999.0 is the existing open-ended "end" placeholder.
    text = body["text"].strip()
    return [{"start": 0.0, "end": 9999.0, "text": text}] if text else []
|
||||||
|
|
||||||
async def _transcribe_local(
    self,
    audio_path: str,
    language: str,
    model_name: str,
    device: str,
    with_segments: bool,
) -> Union[str, list[dict]]:
    """Transcribe ``audio_path`` with the model returned by ``self._get_model``.

    The blocking ``model.transcribe`` call runs in the event loop's default
    executor so the loop stays responsive.

    Args:
        audio_path: Path of the audio file handed to ``model.transcribe``.
        language: Language code passed as ``language=``.
        model_name: Forwarded to ``self._get_model``.
        device: Forwarded to ``self._get_model``.
        with_segments: When true, return timestamped segments instead of text.

    Returns:
        The concatenated, stripped segment text when ``with_segments`` is
        false. Otherwise a list of ``{"start", "end", "text"}`` dicts with
        stripped text, skipping segments whose text is empty after stripping.
    """
    model = self._get_model(model_name, device)

    def _run() -> list:
        # Materialize the segments inside the worker thread. If the backend
        # yields segments lazily (faster-whisper documents its result as a
        # generator), iterating them on the event-loop thread would do the
        # actual decoding there and block the loop.
        segments, _info = model.transcribe(audio_path, language=language)
        return list(segments)

    loop = asyncio.get_running_loop()
    segments = await loop.run_in_executor(None, _run)

    if not with_segments:
        return "".join(seg.text for seg in segments).strip()
    return [
        {"start": seg.start, "end": seg.end, "text": text}
        for seg in segments
        if (text := seg.text.strip())
    ]
|
||||||
|
|
||||||
|
|
||||||
engine = TranscriptionEngine()
|
# Module-level TranscriptionEngine instance, created when this module is imported.
engine = TranscriptionEngine()
|
||||||
|
|||||||
Reference in New Issue
Block a user