feat: transcription module — faster-whisper with ROCm auto-detect
This commit is contained in:
@@ -0,0 +1,25 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcription_engine_is_singleton():
    """The module-level ``engine`` object must be a ``TranscriptionEngine``."""
    # Imported lazily so that collecting this test module does not import
    # the transcription module.
    import transcription

    shared = transcription.engine
    assert isinstance(shared, transcription.TranscriptionEngine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_file_calls_whisper(tmp_path):
    """``transcribe_file`` returns the stripped segment text and forwards its arguments.

    A stub model is injected through ``_model`` so no real model is loaded.
    """
    from transcription import TranscriptionEngine

    # Arrange: a dummy audio file; its content is irrelevant because the
    # model is replaced by a mock.
    audio_file = tmp_path / "test.wav"
    audio_file.write_bytes(b"\x00" * 100)
    audio_path = str(audio_file)

    # Arrange: a stub model yielding one segment with leading whitespace.
    segment = MagicMock()
    segment.text = " Hallo Welt"
    whisper = MagicMock()
    whisper.transcribe.return_value = ([segment], MagicMock())

    subject = TranscriptionEngine()
    subject._model = whisper

    # Act
    transcript = asyncio.run(subject.transcribe_file(audio_path, language="de"))

    # Assert: text is stripped and the model saw exactly one call.
    assert transcript == "Hallo Welt"
    whisper.transcribe.assert_called_once_with(audio_path, language="de")
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionEngine:
    """Lazily loaded faster-whisper model with an async transcription API.

    The model is created on the first call that needs it and is then reused
    by every later call on the same instance.
    """

    # Cached model object; stays None until _get_model() loads one.
    _model = None

    def _get_model(self, model_name: str = "large-v3", device: str = "auto"):
        """Return the cached model, loading it on first use.

        Args:
            model_name: Model identifier passed to ``WhisperModel``.
            device: ``"auto"`` tries ``cuda``/``float16`` first and falls
                back to ``cpu``/``int8`` if that fails. Any other value is
                used as the device, with ``float16`` for ``"cuda"`` and
                ``"rocm"`` and ``int8`` otherwise.

        Returns:
            The loaded (or previously cached) model object.

        Note:
            Once a model is cached, ``model_name`` and ``device`` are
            ignored; the first successful load wins for this instance.
        """
        if self._model is not None:
            return self._model

        # Deferred imports: only needed when a model actually has to be loaded.
        import logging

        from faster_whisper import WhisperModel

        if device == "auto":
            try:
                self._model = WhisperModel(
                    model_name, device="cuda", compute_type="float16"
                )
            except Exception as exc:
                # Best-effort fallback is deliberate, but make it visible:
                # running on CPU is a very different performance profile.
                logging.getLogger(__name__).warning(
                    "GPU model initialisation failed (%s); falling back to CPU/int8",
                    exc,
                )
                self._model = WhisperModel(
                    model_name, device="cpu", compute_type="int8"
                )
        else:
            on_gpu = device in ("cuda", "rocm")
            # NOTE(review): CTranslate2 documents "cpu", "cuda" and "auto" as
            # device names and ROCm/HIP builds are addressed as "cuda";
            # "rocm" is therefore mapped instead of passed through verbatim.
            # Confirm against the installed build.
            backend = "cuda" if device == "rocm" else device
            self._model = WhisperModel(
                model_name,
                device=backend,
                compute_type="float16" if on_gpu else "int8",
            )
        return self._model

    async def transcribe_file(
        self,
        audio_path: str,
        language: str = "de",
        model_name: str = "large-v3",
        device: str = "auto",
    ) -> str:
        """Transcribe an audio file and return the text.

        Args:
            audio_path: Path of the audio file handed to the model.
            language: Language code forwarded to ``model.transcribe``.
            model_name: Model to load if none is cached yet.
            device: Device selection used if no model is cached yet
                (see ``_get_model``).

        Returns:
            The concatenated segment texts with surrounding whitespace
            stripped.

        Note:
            Loading the model (first call only) happens synchronously on the
            calling thread; only the transcription runs in the executor.
        """
        loop = asyncio.get_running_loop()
        model = self._get_model(model_name, device)

        def _run() -> str:
            # faster-whisper returns the segments as a lazy generator, so the
            # decoding work happens while iterating. Consume it here, inside
            # the worker thread, to keep that work off the event loop.
            segments, _info = model.transcribe(audio_path, language=language)
            return "".join(segment.text for segment in segments).strip()

        return await loop.run_in_executor(None, _run)
|
||||||
|
|
||||||
|
|
||||||
|
# Shared module-level instance. Constructing it is cheap: the model itself
# is only loaded the first time a transcription is requested.
engine = TranscriptionEngine()
|
||||||
Reference in New Issue
Block a user