feat: transcription module — faster-whisper with ROCm auto-detect
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
import asyncio
import logging
import threading
|
||||
|
||||
|
||||
class TranscriptionEngine:
    """Lazily load a faster-whisper model and transcribe audio files.

    The model is built on first use and cached together with the
    ``(model_name, device)`` pair it was built for; asking for a different
    pair replaces the cached model so only one model is kept per engine.
    """

    # Most recently loaded WhisperModel; None until the first load.
    _model = None
    # (model_name, device) the cached model was built for. None together
    # with a non-None ``_model`` means the model was assigned directly and
    # is reused as-is.
    _model_key = None
    # Model loading runs in executor worker threads, so construction is
    # serialised to keep concurrent first calls from loading twice.
    _load_lock = threading.Lock()

    def _get_model(self, model_name: str = "large-v3", device: str = "auto"):
        """Return a WhisperModel for the requested configuration.

        Args:
            model_name: Model identifier passed to ``WhisperModel``.
            device: ``"auto"`` tries ``"cuda"`` with float16 and falls back
                to ``"cpu"`` with int8 if construction fails. Any other
                value is used directly, with float16 on the GPU and int8
                otherwise.

        Returns:
            The cached model, reloaded first if the requested
            configuration differs from the cached one.
        """
        requested = (model_name, device)
        with self._load_lock:
            if self._model is not None and self._model_key in (None, requested):
                return self._model

            # Imported lazily so the module can be imported without
            # faster-whisper being loaded up front.
            from faster_whisper import WhisperModel

            if device == "auto":
                try:
                    model = WhisperModel(
                        model_name, device="cuda", compute_type="float16"
                    )
                except Exception:
                    # Deliberate best-effort fallback, but no longer a
                    # silent one: CPU int8 is far slower than the GPU path.
                    logging.getLogger(__name__).warning(
                        "GPU initialisation for model %r failed; "
                        "falling back to CPU int8",
                        model_name,
                        exc_info=True,
                    )
                    model = WhisperModel(
                        model_name, device="cpu", compute_type="int8"
                    )
            else:
                # NOTE(review): the auto branch above treats the GPU backend
                # as "cuda"; ROCm builds of CTranslate2 are understood to use
                # that same device name, so "rocm" is aliased to it. Confirm
                # against the installed CTranslate2 build.
                backend = "cuda" if device == "rocm" else device
                compute = "float16" if backend == "cuda" else "int8"
                model = WhisperModel(
                    model_name, device=backend, compute_type=compute
                )

            self._model = model
            self._model_key = requested
            return model

    async def transcribe_file(
        self,
        audio_path: str,
        language: str = "de",
        model_name: str = "large-v3",
        device: str = "auto",
    ) -> str:
        """Transcribe an audio file without blocking the event loop.

        Args:
            audio_path: Path of the audio file handed to the model.
            language: Language code passed to ``model.transcribe``.
            model_name: Model identifier, see ``_get_model``.
            device: Device selector, see ``_get_model``.

        Returns:
            The concatenated segment texts with surrounding whitespace
            stripped.
        """

        def _run() -> str:
            model = self._get_model(model_name, device)
            segments, _info = model.transcribe(audio_path, language=language)
            # ``segments`` is a lazy generator: the decoding work happens
            # while iterating, so it must be consumed here in the worker
            # thread rather than back on the event loop.
            return "".join(seg.text for seg in segments).strip()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)
|
||||
|
||||
|
||||
# Module-level TranscriptionEngine instance, created when the module is imported.
engine = TranscriptionEngine()
|
||||
Reference in New Issue
Block a user