177 lines
6.1 KiB
Python
177 lines
6.1 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import traceback
|
|
from datetime import datetime
|
|
|
|
from api.state import state, Status
|
|
from api.router import broadcast
|
|
from config import load as load_config
|
|
from transcription import engine as transcription_engine
|
|
from llm import OllamaClient
|
|
from output import save_transcript, write_meeting_docs
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def run_pipeline():
    """Stop the active recording and turn it into saved document(s).

    Reads the recorder and the per-recording options from ``state``, stops
    the recorder, writes the audio to a temporary WAV file and then runs
    either the meeting (diarization) pipeline or the solo pipeline.

    Exceptions raised while saving or processing the audio are not
    propagated: they are logged, stored in ``state.last_error``, and reported
    through ``Status.ERROR`` plus an ``error`` broadcast. The per-recording
    fields on ``state`` are reset and the temporary WAV file is removed in
    every case.

    Returns immediately, without side effects, if ``state`` has no recorder.
    """
    cfg = load_config()
    recorder = getattr(state, "_recorder", None)
    if recorder is None:
        return

    # The cleanup at the end of this function resets ``_recording_output_dir``
    # to None, so the attribute can exist without holding a usable value.
    # Fall back to the configured path in that case too, not only when the
    # attribute is missing entirely.
    output_dir = getattr(state, "_recording_output_dir", None) or cfg["output"]["path"]
    instructions = getattr(state, "_recording_instructions", "") or ""
    diar_cfg = cfg.get("diarization") or {}
    # Diarization needs both the feature flag and a Hugging Face token.
    use_diarization = bool(diar_cfg.get("enabled") and diar_cfg.get("hf_token"))

    recorder.stop()
    await state.set_status(Status.PROCESSING)
    await broadcast({"event": "processing"})

    wav_path = None
    try:
        # delete=False: the file must outlive this handle; it is removed
        # explicitly in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            wav_path = f.name
        # Write only after the handle above is closed, so the recorder can
        # open the path by name on platforms that forbid a second open.
        recorder.save_wav(wav_path)

        if use_diarization:
            await _run_meeting_pipeline(cfg, wav_path, output_dir, instructions, diar_cfg)
        else:
            await _run_solo_pipeline(cfg, wav_path, output_dir, instructions)

    except Exception as e:
        tb = traceback.format_exc()
        logger.error("Pipeline error:\n%s", tb)
        state.last_error = str(e)
        await state.set_status(Status.ERROR)
        await broadcast({"event": "error", "message": str(e)})
    finally:
        # Reset everything that belongs to this one recording.
        state.recording_user = None
        state._recording_output_dir = None
        state._recording_instructions = ""
        state._speakers_event = None
        state._pending_aligned_segments = None
        state._speaker_names = None
        if wav_path:
            try:
                os.unlink(wav_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask
                # the real outcome of the pipeline.
                pass
|
|
|
|
|
|
async def _run_solo_pipeline(cfg, wav_path, output_dir, instructions):
    """Transcribe, refine and save a recording as one document.

    Steps, in order:
      1. Transcribe ``wav_path`` with the settings from ``cfg["whisper"]`` and
         broadcast the raw text in a ``transcribed`` event.
      2. Have the Ollama client refine the raw text using ``instructions``.
      3. Use the first line starting with ``"# "`` in the refined text as the
         title (default ``"Diktat"``), save the transcript to ``output_dir``
         and broadcast a ``saved`` event.
      4. Set the status back to ``Status.IDLE``.
    """
    whisper_cfg = cfg["whisper"]
    transcript = await transcription_engine.transcribe_file(
        wav_path,
        language=whisper_cfg["language"],
        model_name=whisper_cfg["model"],
        device=whisper_cfg["device"],
        base_url=whisper_cfg.get("base_url", ""),
        backend=whisper_cfg.get("backend", "openai"),
    )
    await broadcast({"event": "transcribed", "raw": transcript})

    ollama_cfg = cfg["ollama"]
    llm = OllamaClient(base_url=ollama_cfg["base_url"])
    refined = await llm.refine(
        raw_text=transcript,
        instructions=instructions,
        model=ollama_cfg["model"],
    )

    # First level-one Markdown heading wins; otherwise keep the default title.
    headings = (
        candidate[2:].strip()
        for candidate in refined.splitlines()
        if candidate.startswith("# ")
    )
    title = next(headings, "Diktat")

    saved_path = save_transcript(title=title, content=refined, output_dir=output_dir)
    await broadcast({"event": "saved", "path": saved_path, "title": title})
    await state.set_status(Status.IDLE)
|
|
|
|
|
|
async def _run_meeting_pipeline(cfg, wav_path, output_dir, instructions, diar_cfg):
    """Diarization pipeline: speaker-attributed transcript plus summary.

    Steps, in order:
      1. Run transcription (with segments) and diarization of ``wav_path``
         concurrently, then align the two results into ``(speaker, text)``
         pairs and broadcast the joined text in a ``transcribed`` event.
      2. Ask the Ollama client to identify speakers from the first 20 aligned
         segments. If it returns nothing, switch to
         ``Status.AWAITING_SPEAKERS``, broadcast ``speakers_unknown`` with a
         few excerpts per speaker, and wait until ``state._speakers_event``
         is set; names are then read from ``state._speaker_names``.
      3. Replace speaker labels with names (or a numbered ``Sprecher N``
         fallback), summarize the transcript, write the meeting documents to
         ``output_dir`` and broadcast a ``saved`` event.

    ``instructions`` is accepted but not used by this pipeline.

    Raises whatever the transcription, diarization, LLM or output helpers
    raise; the caller is expected to handle it.
    """
    from diarization import Diarizer
    from alignment import align_segments

    whisper_cfg = cfg["whisper"]
    diarizer = Diarizer(hf_token=diar_cfg["hf_token"])
    whisper_task = asyncio.create_task(
        transcription_engine.transcribe_file(
            wav_path,
            language=whisper_cfg["language"],
            model_name=whisper_cfg["model"],
            device=whisper_cfg["device"],
            base_url=whisper_cfg.get("base_url", ""),
            backend=whisper_cfg.get("backend", "openai"),
            with_segments=True,
        )
    )
    diar_task = asyncio.create_task(diarizer.diarize(wav_path))
    try:
        whisper_segs, speaker_segs = await asyncio.gather(whisper_task, diar_task)
    except Exception:
        # gather() does not cancel the sibling when one task fails. Stop it
        # here so it is not left running against a WAV file that the caller
        # is about to delete, and collect its outcome so no exception goes
        # unretrieved.
        whisper_task.cancel()
        diar_task.cancel()
        await asyncio.gather(whisper_task, diar_task, return_exceptions=True)
        raise

    aligned = align_segments(whisper_segs, speaker_segs)
    await broadcast({"event": "transcribed", "raw": " ".join(t for _, t in aligned)})

    # Only the beginning of the meeting is sent for speaker identification.
    excerpt = "\n".join(f"{s}: {t}" for s, t in aligned[:20])
    client = OllamaClient(base_url=cfg["ollama"]["base_url"])
    name_map = await client.identify_speakers(excerpt, model=cfg["ollama"]["model"])

    if not name_map:
        # No automatic identification: hand over to the user and block until
        # some other part of the application sets the event.
        excerpts_per_speaker = _build_excerpts(aligned)
        state._speakers_event = asyncio.Event()
        state._pending_aligned_segments = aligned
        await state.set_status(Status.AWAITING_SPEAKERS)
        await broadcast({"event": "speakers_unknown", "speakers": [
            {"id": spk, "excerpts": exs}
            for spk, exs in excerpts_per_speaker.items()
        ]})
        await state._speakers_event.wait()
        name_map = state._speaker_names or {}

    def resolve(label):
        """Return the display name for a diarization speaker label."""
        name = name_map.get(label, "")
        if name:
            return name
        suffix = label.replace("SPEAKER_", "")
        if suffix.isdecimal():
            # Labels are presumably zero-based ("SPEAKER_00", "SPEAKER_01",
            # ...) -- TODO confirm against the diarizer. Number them from 1
            # so every label gets a distinct name; stripping leading zeros
            # mapped both "00" and "01" to "Sprecher 1".
            return f"Sprecher {int(suffix) + 1}"
        # Non-numeric label: keep it verbatim, or "1" if nothing is left.
        return f"Sprecher {suffix or '1'}"

    named_aligned = [(resolve(spk), text) for spk, text in aligned]
    speakers = sorted({spk for spk, _ in named_aligned})

    # Duration is the summed length of the transcribed segments, reported in
    # whole minutes and never less than one.
    total_secs = sum(s["end"] - s["start"] for s in whisper_segs) if whisper_segs else 0
    duration_min = max(1, round(total_secs / 60))

    transcript_text = "\n\n".join(f"**{spk}:** {txt}" for spk, txt in named_aligned)
    summary = await client.summarize(transcript_text, model=cfg["ollama"]["model"])

    dt = datetime.now()
    paths = write_meeting_docs(
        aligned_segments=named_aligned,
        summary=summary,
        speakers=speakers,
        duration_min=duration_min,
        output_dir=output_dir,
        dt=dt,
    )

    await state.set_status(Status.IDLE)
    await broadcast({
        "event": "saved",
        "path": paths["index"],
        "title": f"Meeting {dt.strftime('%d.%m.%Y %H:%M')}",
        "meeting": True,
        "paths": paths,
    })
|
|
|
|
|
|
def _build_excerpts(aligned: list[tuple[str, str]], max_per_speaker: int = 4) -> dict[str, list[str]]:
|
|
"""Build a dict of speaker → list of text excerpts."""
|
|
from collections import defaultdict
|
|
buckets: dict[str, list[str]] = defaultdict(list)
|
|
for spk, text in aligned:
|
|
if len(buckets[spk]) < max_per_speaker:
|
|
buckets[spk].append(text[:200])
|
|
return dict(buckets)
|