From 8300851e7720f67dffdf301c1b95949e0153458f Mon Sep 17 00:00:00 2001 From: "thomas.kopp" Date: Wed, 1 Apr 2026 20:28:31 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20remote=20Whisper=20via=20whisper.base?= =?UTF-8?q?=5Furl=20=E2=80=94=20OpenAI-compatible=20upload?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/pipeline.py | 1 + tests/test_transcription.py | 24 ++++++++++++++++++++++++ transcription.py | 22 ++++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/api/pipeline.py b/api/pipeline.py index 3dadf5c..05e7265 100644 --- a/api/pipeline.py +++ b/api/pipeline.py @@ -37,6 +37,7 @@ async def run_pipeline(): language=cfg["whisper"]["language"], model_name=cfg["whisper"]["model"], device=cfg["whisper"]["device"], + base_url=cfg["whisper"].get("base_url", ""), ) await broadcast({"event": "transcribed", "raw": raw_text}) diff --git a/tests/test_transcription.py b/tests/test_transcription.py index e4b65bd..ff3f497 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,4 +1,5 @@ import asyncio +import pytest from unittest.mock import MagicMock @@ -23,3 +24,26 @@ def test_transcribe_file_calls_whisper(tmp_path): result = asyncio.run(eng.transcribe_file(str(wav), language="de")) assert result == "Hallo Welt" mock_model.transcribe.assert_called_once_with(str(wav), language="de") + + +@pytest.mark.asyncio +async def test_transcribe_uses_remote_when_base_url_set(tmp_path): + import wave, struct + wav = tmp_path / "test.wav" + with wave.open(str(wav), "wb") as wf: + wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(16000) + wf.writeframes(struct.pack("<100h", *([0] * 100))) + + import respx, httpx + from transcription import TranscriptionEngine + eng = TranscriptionEngine() + + with respx.mock: + respx.post("http://beastix:8000/v1/audio/transcriptions").mock( + return_value=httpx.Response(200, json={"text": "Hallo Welt"}) + ) + result = await eng.transcribe_file( + str(wav), language="de", model_name="large-v3", + device="auto", base_url="http://beastix:8000", + ) + assert result == "Hallo Welt" diff --git a/transcription.py b/transcription.py index 6de59d2..52d1b91 100644 --- a/transcription.py +++ b/transcription.py @@ -1,4 +1,5 @@ import asyncio +import httpx class TranscriptionEngine: @@ -23,6 +24,27 @@ class TranscriptionEngine: language: str = "de", model_name: str = "large-v3", device: str = "auto", + base_url: str = "", + ) -> str: + if base_url: + return await self._transcribe_remote(audio_path, language, model_name, base_url) + return await self._transcribe_local(audio_path, language, model_name, device) + + async def _transcribe_remote( + self, audio_path: str, language: str, model_name: str, base_url: str + ) -> str: + async with httpx.AsyncClient(timeout=300) as client: + with open(audio_path, "rb") as f: + r = await client.post( + f"{base_url}/v1/audio/transcriptions", + files={"file": ("audio.wav", f, "audio/wav")}, + data={"model": model_name, "language": language}, + ) + r.raise_for_status() + return r.json()["text"] + + async def _transcribe_local( + self, audio_path: str, language: str, model_name: str, device: str ) -> str: loop = asyncio.get_event_loop() model = self._get_model(model_name, device)