feat: remote Whisper via whisper.base_url — OpenAI-compatible upload
This commit is contained in:
@@ -37,6 +37,7 @@ async def run_pipeline():
|
||||
language=cfg["whisper"]["language"],
|
||||
model_name=cfg["whisper"]["model"],
|
||||
device=cfg["whisper"]["device"],
|
||||
base_url=cfg["whisper"].get("base_url", ""),
|
||||
)
|
||||
await broadcast({"event": "transcribed", "raw": raw_text})
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
@@ -23,3 +24,26 @@ def test_transcribe_file_calls_whisper(tmp_path):
|
||||
result = asyncio.run(eng.transcribe_file(str(wav), language="de"))
|
||||
assert result == "Hallo Welt"
|
||||
mock_model.transcribe.assert_called_once_with(str(wav), language="de")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcribe_uses_remote_when_base_url_set(tmp_path):
|
||||
import wave, struct
|
||||
wav = tmp_path / "test.wav"
|
||||
with wave.open(str(wav), "wb") as wf:
|
||||
wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(16000)
|
||||
wf.writeframes(struct.pack("<100h", *([0] * 100)))
|
||||
|
||||
import respx, httpx
|
||||
from transcription import TranscriptionEngine
|
||||
eng = TranscriptionEngine()
|
||||
|
||||
with respx.mock:
|
||||
respx.post("http://beastix:8000/v1/audio/transcriptions").mock(
|
||||
return_value=httpx.Response(200, json={"text": "Hallo Welt"})
|
||||
)
|
||||
result = await eng.transcribe_file(
|
||||
str(wav), language="de", model_name="large-v3",
|
||||
device="auto", base_url="http://beastix:8000",
|
||||
)
|
||||
assert result == "Hallo Welt"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
|
||||
class TranscriptionEngine:
|
||||
@@ -23,6 +24,27 @@ class TranscriptionEngine:
|
||||
language: str = "de",
|
||||
model_name: str = "large-v3",
|
||||
device: str = "auto",
|
||||
base_url: str = "",
|
||||
) -> str:
|
||||
if base_url:
|
||||
return await self._transcribe_remote(audio_path, language, model_name, base_url)
|
||||
return await self._transcribe_local(audio_path, language, model_name, device)
|
||||
|
||||
async def _transcribe_remote(
|
||||
self, audio_path: str, language: str, model_name: str, base_url: str
|
||||
) -> str:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
with open(audio_path, "rb") as f:
|
||||
r = await client.post(
|
||||
f"{base_url}/v1/audio/transcriptions",
|
||||
files={"file": ("audio.wav", f, "audio/wav")},
|
||||
data={"model": model_name, "language": language},
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()["text"]
|
||||
|
||||
async def _transcribe_local(
|
||||
self, audio_path: str, language: str, model_name: str, device: str
|
||||
) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
model = self._get_model(model_name, device)
|
||||
|
||||
Reference in New Issue
Block a user