From 3976ecb52ed100afdb09fdc588445e70b7c9028a Mon Sep 17 00:00:00 2001 From: "thomas.kopp" Date: Wed, 1 Apr 2026 02:22:03 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20transcription=20module=20=E2=80=94=20fa?= =?UTF-8?q?ster-whisper=20with=20ROCm=20auto-detect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_transcription.py | 25 +++++++++++++++++++++++++ transcription.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 tests/test_transcription.py create mode 100644 transcription.py diff --git a/tests/test_transcription.py b/tests/test_transcription.py new file mode 100644 index 0000000..e4b65bd --- /dev/null +++ b/tests/test_transcription.py @@ -0,0 +1,25 @@ +import asyncio +from unittest.mock import MagicMock + + +def test_transcription_engine_is_singleton(): + from transcription import engine, TranscriptionEngine + assert isinstance(engine, TranscriptionEngine) + + +def test_transcribe_file_calls_whisper(tmp_path): + wav = tmp_path / "test.wav" + wav.write_bytes(b"\x00" * 100) + + mock_model = MagicMock() + mock_segment = MagicMock() + mock_segment.text = " Hallo Welt" + mock_model.transcribe.return_value = ([mock_segment], MagicMock()) + + from transcription import TranscriptionEngine + eng = TranscriptionEngine() + eng._model = mock_model + + result = asyncio.run(eng.transcribe_file(str(wav), language="de")) + assert result == "Hallo Welt" + mock_model.transcribe.assert_called_once_with(str(wav), language="de") diff --git a/transcription.py b/transcription.py new file mode 100644 index 0000000..6de59d2 --- /dev/null +++ b/transcription.py @@ -0,0 +1,36 @@ +import asyncio + + +class TranscriptionEngine: + _model = None + + def _get_model(self, model_name: str = "large-v3", device: str = "auto"): + if self._model is None: + from faster_whisper import WhisperModel + if device == "auto": + try: + self._model = WhisperModel(model_name, device="cuda", compute_type="float16") + except Exception: + self._model = WhisperModel(model_name, device="cpu", compute_type="int8") + else: + compute = "float16" if device in ("cuda", "rocm") else "int8" + self._model = WhisperModel(model_name, device=device, compute_type=compute) + return self._model + + async def transcribe_file( + self, + audio_path: str, + language: str = "de", + model_name: str = "large-v3", + device: str = "auto", + ) -> str: + loop = asyncio.get_event_loop() + model = self._get_model(model_name, device) + segments, _ = await loop.run_in_executor( + None, + lambda: model.transcribe(audio_path, language=language), + ) + return "".join(seg.text for seg in segments).strip() + + +engine = TranscriptionEngine()