diff --git a/audio.py b/audio.py new file mode 100644 index 0000000..5b64345 --- /dev/null +++ b/audio.py @@ -0,0 +1,46 @@ +import wave +import threading +import numpy as np + + +class AudioRecorder: + def __init__(self, sample_rate: int = 16000): + self.sample_rate = sample_rate + self._buffer: list[np.ndarray] = [] + self._stream = None + self.is_recording = False + self._lock = threading.Lock() + + def _callback(self, indata, frames, time, status): + if self.is_recording: + with self._lock: + self._buffer.append(indata[:, 0].copy().astype(np.int16)) + + def start(self): + import sounddevice as sd + self._buffer = [] + self.is_recording = True + self._stream = sd.InputStream( + samplerate=self.sample_rate, + channels=1, + dtype="int16", + callback=self._callback, + ) + self._stream.start() + + def stop(self): + self.is_recording = False + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + def save_wav(self, path: str) -> str: + with self._lock: + data = np.concatenate(self._buffer) if self._buffer else np.zeros(0, dtype=np.int16) + with wave.open(path, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(self.sample_rate) + wf.writeframes(data.tobytes()) + return path diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 0000000..fef3f84 --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,29 @@ +import numpy as np +from unittest.mock import patch, MagicMock + + +def test_recorder_starts_and_stops(): + from audio import AudioRecorder + with patch("sounddevice.InputStream") as MockStream: + mock_stream = MagicMock() + MockStream.return_value.start = MagicMock() + MockStream.return_value.stop = MagicMock() + MockStream.return_value.close = MagicMock() + recorder = AudioRecorder(sample_rate=16000) + assert not recorder.is_recording + recorder._stream = MockStream.return_value + recorder.is_recording = True + recorder.stop() + assert not recorder.is_recording + + +def test_recorder_save_wav(tmp_path): + import wave + from audio import AudioRecorder + recorder = AudioRecorder(sample_rate=16000) + recorder._buffer = [np.zeros(1600, dtype=np.int16)] + out = str(tmp_path / "test.wav") + recorder.save_wav(out) + with wave.open(out) as wf: + assert wf.getframerate() == 16000 + assert wf.getnchannels() == 1