diff --git a/tests/test_transcription.py b/tests/test_transcription.py
new file mode 100644
index 0000000..e4b65bd
--- /dev/null
+++ b/tests/test_transcription.py
@@ -0,0 +1,29 @@
+import asyncio
+from unittest.mock import MagicMock
+
+
+def test_transcription_engine_is_singleton():
+    """The module exposes a shared ``engine`` that is a TranscriptionEngine."""
+    from transcription import engine, TranscriptionEngine
+    assert isinstance(engine, TranscriptionEngine)
+
+
+def test_transcribe_file_calls_whisper(tmp_path):
+    """transcribe_file forwards path/language to the model and strips the text."""
+    wav = tmp_path / "test.wav"
+    wav.write_bytes(b"\x00" * 100)
+
+    # Stand-in for the Whisper model: one segment with leading whitespace.
+    mock_model = MagicMock()
+    mock_segment = MagicMock()
+    mock_segment.text = " Hallo Welt"
+    mock_model.transcribe.return_value = ([mock_segment], MagicMock())
+
+    from transcription import TranscriptionEngine
+    eng = TranscriptionEngine()
+    # Pre-seed the cache so _get_model() never tries to build a real model.
+    eng._model = mock_model
+
+    result = asyncio.run(eng.transcribe_file(str(wav), language="de"))
+    assert result == "Hallo Welt"
+    mock_model.transcribe.assert_called_once_with(str(wav), language="de")
diff --git a/transcription.py b/transcription.py
new file mode 100644
index 0000000..6de59d2
--- /dev/null
+++ b/transcription.py
@@ -0,0 +1,65 @@
+import asyncio
+
+
+class TranscriptionEngine:
+    """Lazily build a faster-whisper model and transcribe audio files with it."""
+
+    # Cached model; stays None until _get_model() builds one (tests may preset it).
+    _model = None
+
+    def _get_model(self, model_name: str = "large-v3", device: str = "auto"):
+        """Return the cached model, building it on the first call.
+
+        With ``device="auto"`` a CUDA/float16 model is tried first and any
+        failure falls back to CPU/int8. Once a model is cached it is returned
+        as-is; ``model_name`` and ``device`` of later calls are ignored.
+        """
+        if self._model is None:
+            # Deferred import: faster_whisper is only needed to build a model.
+            from faster_whisper import WhisperModel
+            if device == "auto":
+                try:
+                    self._model = WhisperModel(model_name, device="cuda", compute_type="float16")
+                except Exception:
+                    # Deliberate best-effort: any CUDA failure selects the CPU model.
+                    self._model = WhisperModel(model_name, device="cpu", compute_type="int8")
+            else:
+                compute = "float16" if device in ("cuda", "rocm") else "int8"
+                self._model = WhisperModel(model_name, device=device, compute_type=compute)
+        return self._model
+
+    async def transcribe_file(
+        self,
+        audio_path: str,
+        language: str = "de",
+        model_name: str = "large-v3",
+        device: str = "auto",
+    ) -> str:
+        """Transcribe ``audio_path`` and return the concatenated segment text.
+
+        Args:
+            audio_path: Path of the audio file handed to the model.
+            language: Language code passed through to ``model.transcribe``.
+            model_name: Model to build if none is cached yet.
+            device: Device to build the model on if none is cached yet.
+
+        Returns:
+            All segment texts joined together, with surrounding whitespace
+            stripped.
+        """
+        # NOTE: the first call builds the model synchronously on this thread.
+        model = self._get_model(model_name, device)
+
+        def _transcribe() -> str:
+            # faster-whisper returns segments as a lazy generator, so decoding
+            # only happens while iterating. Consume it here, in the worker
+            # thread, so the event loop is not blocked by the join.
+            segments, _ = model.transcribe(audio_path, language=language)
+            return "".join(seg.text for seg in segments).strip()
+
+        # get_running_loop() is the supported way to get the loop in a coroutine.
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, _transcribe)
+
+
+engine = TranscriptionEngine()