diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 521f30013d..217f5ba693 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -156,7 +156,16 @@ def get_binary_description(self): class BinaryRecordingSegment(BaseRecordingSegment): def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) - self._timeseries = read_binary_recording(datfile, num_chan, dtype, time_axis, file_offset) + self.num_chan = num_chan + self.dtype = np.dtype(dtype) + self.file_offset = file_offset + self.time_axis = time_axis + self.datfile = datfile + self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize) + if self.time_axis == 0: + self.shape = (self.num_samples, self.num_chan) + else: + self.shape = (self.num_chan, self.num_samples) def get_num_samples(self) -> int: """Returns the number of samples in this signal block @@ -164,7 +173,7 @@ def get_num_samples(self) -> int: Returns: SampleIndex: Number of samples in the signal block """ - return self._timeseries.shape[0] + return self.num_samples def get_traces( self, @@ -172,9 +181,14 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - traces = self._timeseries[start_frame:end_frame] + data = np.memmap(self.datfile, self.dtype, mode="r", offset=self.file_offset, shape=self.shape) + if self.time_axis == 1: + data = data.T + + traces = data[start_frame:end_frame] if channel_indices is not None: traces = traces[:, channel_indices] + return traces diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index e8815e1c6b..634c2903d2 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -3,7 +3,7 @@ from pathlib import Path from spikeinterface.core import BinaryRecordingExtractor - +from spikeinterface.core.numpyextractors import NumpyRecording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -32,5 +32,25 @@ def test_BinaryRecordingExtractor(): assert (cache_folder / "test_BinaryRecordingExtractor_copied_0.raw").is_file() +def test_round_trip(tmp_path): + num_channels = 10 + num_samples = 50 + traces_list = [np.ones(shape=(num_samples, num_channels), dtype="int32")] + sampling_frequency = 30_000.0 + recording = NumpyRecording(traces_list=traces_list, sampling_frequency=sampling_frequency) + + file_path = tmp_path / "test_BinaryRecordingExtractor.raw" + dtype = recording.get_dtype() + BinaryRecordingExtractor.write_recording(recording=recording, dtype=dtype, file_paths=file_path) + + sampling_frequency = recording.get_sampling_frequency() + num_chan = recording.get_num_channels() + binary_recorder = BinaryRecordingExtractor( + file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype + ) + + assert np.allclose(recording.get_traces(), binary_recorder.get_traces()) + + if __name__ == "__main__": test_BinaryRecordingExtractor()