diff --git a/config.py b/config.py new file mode 100644 index 0000000..5d4d7ae --- /dev/null +++ b/config.py @@ -0,0 +1,58 @@ +import os +import tomllib + +CONFIG_PATH = os.path.expanduser("~/.config/tueit-transcriber/config.toml") + +DEFAULTS = { + "ollama": { + "base_url": "http://localhost:11434", + "model": "gemma3:12b", + }, + "whisper": { + "model": "large-v3", + "language": "de", + "device": "auto", # "auto" = use GPU if ROCm available, else CPU + }, + "server": { + "port": 8765, + }, + "output": { + "path": os.path.expanduser( + "~/cloud.shron.de/Hetzner Storagebox/work" + ), + }, + "pid_file": os.path.expanduser("~/.local/run/tueit-transcriber.pid"), +} + + +def load() -> dict: + os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True) + if not os.path.exists(CONFIG_PATH): + _write_defaults() + with open(CONFIG_PATH, "rb") as f: + on_disk = tomllib.load(f) + return _deep_merge(DEFAULTS, on_disk) + + +def _deep_merge(base: dict, override: dict) -> dict: + result = dict(base) + for k, v in override.items(): + if k in result and isinstance(result[k], dict) and isinstance(v, dict): + result[k] = _deep_merge(result[k], v) + else: + result[k] = v + return result + + +def _write_defaults(): + try: + import tomli_w + with open(CONFIG_PATH, "wb") as f: + tomli_w.dump(DEFAULTS, f) + except ImportError: + with open(CONFIG_PATH, "w") as f: + f.write("# tüit Transkriptor config\n\n") + f.write('[ollama]\nbase_url = "http://localhost:11434"\nmodel = "gemma3:12b"\n\n') + f.write('[whisper]\nmodel = "large-v3"\nlanguage = "de"\ndevice = "auto"\n\n') + f.write('[server]\nport = 8765\n\n') + f.write(f'[output]\npath = "{DEFAULTS["output"]["path"]}"\n') diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..726f542 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,25 @@ +import os +import tempfile +from unittest.mock import patch + + +def test_config_loads_defaults(): + with tempfile.TemporaryDirectory() as tmpdir: + cfg_path = os.path.join(tmpdir, "config.toml") + with patch("config.CONFIG_PATH", cfg_path): + import importlib, config + importlib.reload(config) + cfg = config.load() + assert cfg["ollama"]["model"] == "gemma3:12b" + assert cfg["whisper"]["model"] == "large-v3" + assert cfg["server"]["port"] == 8765 + + +def test_config_creates_file_on_first_run(): + with tempfile.TemporaryDirectory() as tmpdir: + import importlib, config + importlib.reload(config) + cfg_path = os.path.join(tmpdir, "config.toml") + with patch("config.CONFIG_PATH", cfg_path): + config.load() + assert os.path.exists(cfg_path)