diff --git a/api/pipeline.py b/api/pipeline.py
index 05e7265..fef7a3f 100644
--- a/api/pipeline.py
+++ b/api/pipeline.py
@@ -1,16 +1,18 @@
+import asyncio
 import logging
 import os
 import tempfile
 import traceback
+from datetime import datetime
 
 from api.state import state, Status
-
-logger = logging.getLogger(__name__)
+from api.router import broadcast
 from config import load as load_config
 from transcription import engine as transcription_engine
 from llm import OllamaClient
-from output import save_transcript
-from api.router import broadcast
+from output import save_transcript, write_meeting_docs
+
+logger = logging.getLogger(__name__)
 
 
 async def run_pipeline():
@@ -21,6 +23,8 @@ async def run_pipeline():
         output_dir = getattr(state, "_recording_output_dir", cfg["output"]["path"])
         instructions = getattr(state, "_recording_instructions", "")
 
+        diar_cfg = cfg.get("diarization", {})
+        use_diarization = diar_cfg.get("enabled") and diar_cfg.get("hf_token")  # needs a Hugging Face token
         recorder.stop()
         await state.set_status(Status.PROCESSING)
 
@@ -32,36 +36,10 @@
             wav_path = f.name
         recorder.save_wav(wav_path)
 
-        raw_text = await transcription_engine.transcribe_file(
-            wav_path,
-            language=cfg["whisper"]["language"],
-            model_name=cfg["whisper"]["model"],
-            device=cfg["whisper"]["device"],
-            base_url=cfg["whisper"].get("base_url", ""),
-        )
-        await broadcast({"event": "transcribed", "raw": raw_text})
-
-        client = OllamaClient(base_url=cfg["ollama"]["base_url"])
-        refined = await client.refine(
-            raw_text=raw_text,
-            instructions=instructions,
-            model=cfg["ollama"]["model"],
-        )
-        await broadcast({"event": "refined", "markdown": refined})
-
-        title = "Diktat"
-        for line in refined.splitlines():
-            if line.startswith("# "):
-                title = line[2:].strip()
-                break
-
-        path = save_transcript(
-            title=title,
-            content=refined,
-            output_dir=output_dir,
-        )
-        await broadcast({"event": "saved", "path": path, "title": title})
-        await state.set_status(Status.IDLE)
+        if use_diarization:
+            await _run_meeting_pipeline(cfg, wav_path, output_dir, instructions, diar_cfg)
+        else:
+            await _run_solo_pipeline(cfg, wav_path, output_dir, instructions)
 
     except Exception as e:
         tb = traceback.format_exc()
@@ -73,8 +51,125 @@
         state.recording_user = None
         state._recording_output_dir = None
         state._recording_instructions = ""
+        state._speakers_event = None
+        state._pending_aligned_segments = None
+        state._speaker_names = None
         if wav_path:
             try:
                 os.unlink(wav_path)
             except OSError:
                 pass
+
+
+async def _run_solo_pipeline(cfg, wav_path, output_dir, instructions):
+    """Original single-document pipeline (no diarization)."""
+    raw_text = await transcription_engine.transcribe_file(
+        wav_path,
+        language=cfg["whisper"]["language"],
+        model_name=cfg["whisper"]["model"],
+        device=cfg["whisper"]["device"],
+        base_url=cfg["whisper"].get("base_url", ""),
+    )
+    await broadcast({"event": "transcribed", "raw": raw_text})
+
+    client = OllamaClient(base_url=cfg["ollama"]["base_url"])
+    refined = await client.refine(
+        raw_text=raw_text,
+        instructions=instructions,
+        model=cfg["ollama"]["model"],
+    )
+    await broadcast({"event": "refined", "markdown": refined})
+
+    title = "Diktat"
+    for line in refined.splitlines():
+        if line.startswith("# "):
+            title = line[2:].strip()
+            break
+
+    path = save_transcript(title=title, content=refined, output_dir=output_dir)
+    await broadcast({"event": "saved", "path": path, "title": title})
+    await state.set_status(Status.IDLE)
+
+
+async def _run_meeting_pipeline(cfg, wav_path, output_dir, instructions, diar_cfg):
+    """Diarization pipeline: 3 documents, speaker identification."""
identification.""" + from diarization import Diarizer + from alignment import align_segments + + diarizer = Diarizer(hf_token=diar_cfg["hf_token"]) + whisper_task = asyncio.create_task( + transcription_engine.transcribe_file( + wav_path, + language=cfg["whisper"]["language"], + model_name=cfg["whisper"]["model"], + device=cfg["whisper"]["device"], + base_url=cfg["whisper"].get("base_url", ""), + with_segments=True, + ) + ) + diar_task = asyncio.create_task(diarizer.diarize(wav_path)) + whisper_segs, speaker_segs = await asyncio.gather(whisper_task, diar_task) + + aligned = align_segments(whisper_segs, speaker_segs) + await broadcast({"event": "transcribed", "raw": " ".join(t for _, t in aligned)}) + + excerpt = "\n".join(f"{s}: {t}" for s, t in aligned[:20]) + client = OllamaClient(base_url=cfg["ollama"]["base_url"]) + name_map = await client.identify_speakers(excerpt, model=cfg["ollama"]["model"]) + + if not name_map: + excerpts_per_speaker = _build_excerpts(aligned) + state._speakers_event = asyncio.Event() + state._pending_aligned_segments = aligned + await state.set_status(Status.AWAITING_SPEAKERS) + await broadcast({"event": "speakers_unknown", "speakers": [ + {"id": spk, "excerpts": exs} + for spk, exs in excerpts_per_speaker.items() + ]}) + await state._speakers_event.wait() + name_map = state._speaker_names or {} + + def resolve(label): + name = name_map.get(label, "") + if name: + return name + num = label.replace("SPEAKER_", "").lstrip("0") or "1" + return f"Sprecher {num}" + + named_aligned = [(resolve(spk), text) for spk, text in aligned] + speakers = sorted({spk for spk, _ in named_aligned}) + + total_secs = sum(s["end"] - s["start"] for s in whisper_segs) if whisper_segs else 0 + duration_min = max(1, round(total_secs / 60)) + + transcript_text = "\n\n".join(f"**{spk}:** {txt}" for spk, txt in named_aligned) + summary = await client.summarize(transcript_text, model=cfg["ollama"]["model"]) + + dt = datetime.now() + paths = write_meeting_docs( + aligned_segments=named_aligned, + summary=summary, + speakers=speakers, + duration_min=duration_min, + output_dir=output_dir, + dt=dt, + ) + + await state.set_status(Status.IDLE) + await broadcast({ + "event": "saved", + "path": paths["index"], + "title": f"Meeting {dt.strftime('%d.%m.%Y %H:%M')}", + "meeting": True, + "paths": paths, + }) + + +def _build_excerpts(aligned: list[tuple[str, str]], max_per_speaker: int = 4) -> dict[str, list[str]]: + """Build a dict of speaker → list of text excerpts.""" + from collections import defaultdict + buckets: dict[str, list[str]] = defaultdict(list) + for spk, text in aligned: + if len(buckets[spk]) < max_per_speaker: + buckets[spk].append(text[:200]) + return dict(buckets)