feat: remote Whisper via whisper.base_url — OpenAI-compatible upload
This commit is contained in:
@@ -37,6 +37,7 @@ async def run_pipeline():
|
|||||||
language=cfg["whisper"]["language"],
|
language=cfg["whisper"]["language"],
|
||||||
model_name=cfg["whisper"]["model"],
|
model_name=cfg["whisper"]["model"],
|
||||||
device=cfg["whisper"]["device"],
|
device=cfg["whisper"]["device"],
|
||||||
|
base_url=cfg["whisper"].get("base_url", ""),
|
||||||
)
|
)
|
||||||
await broadcast({"event": "transcribed", "raw": raw_text})
|
await broadcast({"event": "transcribed", "raw": raw_text})
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import pytest
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
@@ -23,3 +24,26 @@ def test_transcribe_file_calls_whisper(tmp_path):
|
|||||||
result = asyncio.run(eng.transcribe_file(str(wav), language="de"))
|
result = asyncio.run(eng.transcribe_file(str(wav), language="de"))
|
||||||
assert result == "Hallo Welt"
|
assert result == "Hallo Welt"
|
||||||
mock_model.transcribe.assert_called_once_with(str(wav), language="de")
|
mock_model.transcribe.assert_called_once_with(str(wav), language="de")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcribe_uses_remote_when_base_url_set(tmp_path):
|
||||||
|
import wave, struct
|
||||||
|
wav = tmp_path / "test.wav"
|
||||||
|
with wave.open(str(wav), "wb") as wf:
|
||||||
|
wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(16000)
|
||||||
|
wf.writeframes(struct.pack("<100h", *([0] * 100)))
|
||||||
|
|
||||||
|
import respx, httpx
|
||||||
|
from transcription import TranscriptionEngine
|
||||||
|
eng = TranscriptionEngine()
|
||||||
|
|
||||||
|
with respx.mock:
|
||||||
|
respx.post("http://beastix:8000/v1/audio/transcriptions").mock(
|
||||||
|
return_value=httpx.Response(200, json={"text": "Hallo Welt"})
|
||||||
|
)
|
||||||
|
result = await eng.transcribe_file(
|
||||||
|
str(wav), language="de", model_name="large-v3",
|
||||||
|
device="auto", base_url="http://beastix:8000",
|
||||||
|
)
|
||||||
|
assert result == "Hallo Welt"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
@@ -23,6 +24,27 @@ class TranscriptionEngine:
|
|||||||
language: str = "de",
|
language: str = "de",
|
||||||
model_name: str = "large-v3",
|
model_name: str = "large-v3",
|
||||||
device: str = "auto",
|
device: str = "auto",
|
||||||
|
base_url: str = "",
|
||||||
|
) -> str:
|
||||||
|
if base_url:
|
||||||
|
return await self._transcribe_remote(audio_path, language, model_name, base_url)
|
||||||
|
return await self._transcribe_local(audio_path, language, model_name, device)
|
||||||
|
|
||||||
|
async def _transcribe_remote(
|
||||||
|
self, audio_path: str, language: str, model_name: str, base_url: str
|
||||||
|
) -> str:
|
||||||
|
async with httpx.AsyncClient(timeout=300) as client:
|
||||||
|
with open(audio_path, "rb") as f:
|
||||||
|
r = await client.post(
|
||||||
|
f"{base_url}/v1/audio/transcriptions",
|
||||||
|
files={"file": ("audio.wav", f, "audio/wav")},
|
||||||
|
data={"model": model_name, "language": language},
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.json()["text"]
|
||||||
|
|
||||||
|
async def _transcribe_local(
|
||||||
|
self, audio_path: str, language: str, model_name: str, device: str
|
||||||
) -> str:
|
) -> str:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
model = self._get_model(model_name, device)
|
model = self._get_model(model_name, device)
|
||||||
|
|||||||
Reference in New Issue
Block a user