From 5d4ce51f3463b636a4371852d9046180a7d1a169 Mon Sep 17 00:00:00 2001
From: Julia Sprenger
Date: Tue, 9 Jan 2018 18:40:59 +0100
Subject: [PATCH 1/7] Reorder waveform dimensions to (time,spike,channel) to
 match dimension order of AnalogSignals

previous dimension order was (spike,channel,time)
---
 neo/core/spiketrain.py               | 27 ++++++++-------
 neo/test/coretest/test_spiketrain.py | 49 ++++++++++++++++------------
 neo/test/generate_datasets.py        | 12 +++----
 3 files changed, 48 insertions(+), 40 deletions(-)

diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py
index f3cab1667..8d97be12b 100644
--- a/neo/core/spiketrain.py
+++ b/neo/core/spiketrain.py
@@ -73,7 +73,7 @@ def _check_time_in_range(value, t_start, t_stop, view=False):
 def _check_waveform_dimensions(spiketrain):
     '''
     Verify that waveform is compliant with the waveform definition as
-    quantity array 3D (spike, channel_index, time)
+    quantity array 3D (time, spike, channel_index)
     '''

     if not spiketrain.size:
@@ -84,10 +84,10 @@ def _check_waveform_dimensions(spiketrain):
     if (waveforms is None) or (not waveforms.size):
         return

-    if waveforms.shape[0] != len(spiketrain):
+    if waveforms.shape[1] != len(spiketrain):
         raise ValueError("Spiketrain length (%s) does not match to number of "
                          "waveforms present (%s)" % (len(spiketrain),
-                                                     waveforms.shape[0]))
+                                                     waveforms.shape[1]))


 def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None,
@@ -158,7 +158,7 @@ class SpikeTrain(BaseNeo, pq.Quantity):
         :class:`SpikeTrain` began. This will be converted to the
         same units as :attr:`times`.
         Default: 0.0 seconds.
-    :waveforms: (quantity array 3D (spike, channel_index, time))
+    :waveforms: (quantity array 3D (time, spike, channel_index))
         The waveforms of each spike.
     :sampling_rate: (quantity scalar) Number of samples per unit time
         for the waveforms.
@@ -181,7 +181,7 @@ class SpikeTrain(BaseNeo, pq.Quantity):
         read-only.
         (:attr:`t_stop` - :attr:`t_start`)
     :spike_duration: (quantity scalar) Duration of a waveform, read-only.
-        (:attr:`waveform`.shape[2] * :attr:`sampling_period`)
+        (:attr:`waveform`.shape[0] * :attr:`sampling_period`)
     :right_sweep: (quantity scalar) Time from the trigger times of the
         spikes to the end of the waveforms, read-only.
         (:attr:`left_sweep` + :attr:`spike_duration`)
@@ -218,8 +218,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True,
         constructor, but not when slicing.
''' if len(times) != 0 and waveforms is not None and len(times) != \ - waveforms.shape[ - 0]: # len(times)!=0 has been used to workaround a bug occuring during neo import) + waveforms.shape[1]: raise ValueError( "the number of waveforms should be equal to the number of spikes") @@ -433,7 +432,7 @@ def sort(self): # sort the waveforms by the times sort_indices = np.argsort(self) if self.waveforms is not None and self.waveforms.any(): - self.waveforms = self.waveforms[sort_indices] + self.waveforms = self.waveforms[:,sort_indices,:] # now sort the times # We have sorted twice, but `self = self[sort_indices]` introduces @@ -490,7 +489,7 @@ def __getitem__(self, i): ''' obj = super(SpikeTrain, self).__getitem__(i) if hasattr(obj, 'waveforms') and obj.waveforms is not None: - obj.waveforms = obj.waveforms.__getitem__(i) + obj.waveforms = obj.waveforms.__getitem__([slice(None),i,slice(None)]) return obj def __setitem__(self, i, value): @@ -568,7 +567,7 @@ def time_slice(self, t_start, t_stop): new_st.t_start = max(_t_start, self.t_start) new_st.t_stop = min(_t_stop, self.t_stop) if self.waveforms is not None: - new_st.waveforms = self.waveforms[indices] + new_st.waveforms = self.waveforms[:,indices,:] return new_st @@ -625,8 +624,8 @@ def merge(self, other): sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs) if all(wfs): - wfs_stack = np.vstack((self.waveforms, other.waveforms)) - wfs_stack = wfs_stack[sorting] + wfs_stack = np.concatenate((self.waveforms, other.waveforms),axis=1) + wfs_stack = wfs_stack[:,sorting,:] train.waveforms = wfs_stack train.segment = self.segment if train.segment is not None: @@ -659,11 +658,11 @@ def spike_duration(self): ''' Duration of a waveform. - (:attr:`waveform`.shape[2] * :attr:`sampling_period`) + (:attr:`waveform`.shape[0] * :attr:`sampling_period`) ''' if self.waveforms is None or self.sampling_rate is None: return None - return self.waveforms.shape[2] / self.sampling_rate + return self.waveforms.shape[0] / self.sampling_rate @property def sampling_period(self): diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 6e1fe40af..3e27c98c1 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -41,9 +41,9 @@ def setUp(self): def test__get_fake_values(self): self.annotations['seed'] = 0 waveforms = get_fake_value('waveforms', pq.Quantity, seed=3, dim=3) - shape = waveforms.shape[0] + shape = waveforms.shape[1] times = get_fake_value('times', pq.Quantity, seed=0, dim=1, - shape=waveforms.shape[0]) + shape=waveforms.shape[1]) t_start = get_fake_value('t_start', pq.Quantity, seed=1, dim=0) t_stop = get_fake_value('t_stop', pq.Quantity, seed=2, dim=0) left_sweep = get_fake_value('left_sweep', pq.Quantity, seed=4, dim=0) @@ -815,6 +815,7 @@ def test_tstop_units_conversion(self): class TestSorting(unittest.TestCase): def test_sort(self): waveforms = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV + waveforms = np.moveaxis(waveforms, 2, 0) train = SpikeTrain([3, 4, 5] * pq.s, waveforms=waveforms, name='n', t_stop=10.0) assert_neo_object_is_compliant(train) @@ -831,7 +832,7 @@ def test_sort(self): train.sort() assert_neo_object_is_compliant(train) assert_arrays_equal(train, [3, 4, 5] * pq.s) - assert_arrays_equal(train.waveforms, waveforms[[0, 2, 1]]) + assert_arrays_equal(train.waveforms, waveforms[:,[0, 2, 1],:]) self.assertEqual(train.name, 'n') self.assertEqual(train.t_start, 0.0 * pq.s) self.assertEqual(train.t_stop, 10.0 * pq.s) @@ -845,6 +846,7 @@ def 
setUp(self): [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1,2,0) self.data1 = np.array([3, 4, 5]) self.data1quant = self.data1 * pq.s self.train1 = SpikeTrain(self.data1quant, waveforms=self.waveforms1, @@ -859,6 +861,7 @@ def test_slice(self): assert_arrays_equal(self.train1[1:2], result) targwaveforms = np.array([[[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -871,7 +874,7 @@ def test_slice(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[1:2], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:,1:2,:], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) def test_slice_to_end(self): @@ -882,6 +885,7 @@ def test_slice_to_end(self): [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -894,7 +898,7 @@ def test_slice_to_end(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[1:], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:,1:,:], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) def test_slice_from_beginning(self): @@ -905,6 +909,7 @@ def test_slice_from_beginning(self): [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -917,7 +922,7 @@ def test_slice_from_beginning(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[:2], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:,:2,:], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) def test_slice_negative_idxs(self): @@ -928,6 +933,7 @@ def test_slice_negative_idxs(self): [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -940,7 +946,7 @@ def test_slice_negative_idxs(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[:-1], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:,:-1,:], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) @@ -958,6 +964,7 @@ def setUp(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1,2,0) self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.data1quant = self.data1 * pq.ms self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, @@ -981,6 +988,7 @@ def test_time_slice_typical(self): [4.1, 5.1]], [[6., 7.], [6.1, 7.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1006,6 +1014,7 @@ def test_time_slice_differnt_units(self): [4.1, 5.1]], [[6., 7.], [6.1, 7.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1067,7 +1076,7 @@ def test_time_slice_empty(self): t_stop = 70.0 * pq.ms result = train.time_slice(t_start, t_stop) assert_arrays_equal(train, result) - 
assert_arrays_equal(waveforms[:-1], result.waveforms) + assert_arrays_equal(waveforms[:,:-1,:], result.waveforms) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -1092,6 +1101,7 @@ def test_time_slice_none_stop(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1113,6 +1123,7 @@ def test_time_slice_none_start(self): [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1159,6 +1170,7 @@ def setUp(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1,2,0) self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.data1quant = self.data1 * pq.ms self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, @@ -1176,6 +1188,7 @@ def setUp(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms2 = np.moveaxis(self.waveforms2,2,0) self.data2 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.data2quant = self.data1 * pq.ms self.train2 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, @@ -1253,18 +1266,13 @@ def test_incompatible_t_start(self): class TestDuplicateWithNewData(unittest.TestCase): def setUp(self): - self.waveforms = np.array([[[0., 1.], - [0.1, 1.1]], - [[2., 3.], - [2.1, 3.1]], - [[4., 5.], - [4.1, 5.1]], - [[6., 7.], - [6.1, 7.1]], - [[8., 9.], - [8.1, 9.1]], - [[10., 11.], - [10.1, 11.1]]]) * pq.mV + self.waveforms = np.array([[[0., 1.], [0.1, 1.1]], + [[2., 3.], [2.1, 3.1]], + [[4., 5.], [4.1, 5.1]], + [[6., 7.], [6.1, 7.1]], + [[8., 9.], [8.1, 9.1]], + [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms = np.moveaxis(self.waveforms,2,0) self.data = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.dataquant = self.data * pq.ms self.train = SpikeTrain(self.dataquant, t_stop=10.0 * pq.ms, @@ -1540,6 +1548,7 @@ def setUp(self): [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1,2,0) self.t_start1 = 0.5 self.t_stop1 = 10.0 self.t_start1quant = self.t_start1 * pq.ms diff --git a/neo/test/generate_datasets.py b/neo/test/generate_datasets.py index 6f9dc6707..6a47757cb 100644 --- a/neo/test/generate_datasets.py +++ b/neo/test/generate_datasets.py @@ -284,21 +284,21 @@ def get_fake_values(cls, annotate=True, seed=None, n=None): iseed = None kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=cls, n=n) - if 'waveforms' in kwargs : #everything here is to force the kwargs to have len(time) == kwargs["waveforms"].shape[0] - if len(kwargs["times"]) != kwargs["waveforms"].shape[0] : - if len(kwargs["times"]) < kwargs["waveforms"].shape[0] : + if 'waveforms' in kwargs : #everything here is to force the kwargs to have len(time) == kwargs["waveforms"].shape[1] + if len(kwargs["times"]) != kwargs["waveforms"].shape[1] : + if len(kwargs["times"]) < kwargs["waveforms"].shape[1] : - dif = kwargs["waveforms"].shape[0] - len(kwargs["times"]) + dif = kwargs["waveforms"].shape[1] - len(kwargs["times"]) new_times =[] - for i in kwargs["times"].magnitude : + for i in kwargs["times"].magnitude: new_times.append(i) np.random.seed(0) new_times = np.concatenate([new_times, np.random.random(dif)]) kwargs["times"] = pq.Quantity(new_times, units=pq.ms) else : - kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[0]] + kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[1]] if 
'times' in kwargs and 'signal' in kwargs: kwargs['times'] = kwargs['times'][:len(kwargs['signal'])] From 66812b99c6bef7c5420bed14abab0b38a312780c Mon Sep 17 00:00:00 2001 From: Julia Sprenger Date: Wed, 24 Jan 2018 15:21:51 +0100 Subject: [PATCH 2/7] Adjust rawios to waveform dimension switch --- neo/rawio/blackrockrawio.py | 133 ++++++++++++----------- neo/rawio/examplerawio.py | 106 +++++++++--------- neo/rawio/neuralynxrawio.py | 187 ++++++++++++++++---------------- neo/rawio/neuroexplorerrawio.py | 69 ++++++------ neo/rawio/plexonrawio.py | 111 +++++++++---------- neo/rawio/spike2rawio.py | 141 ++++++++++++------------ neo/rawio/tdtrawio.py | 107 +++++++++--------- 7 files changed, 430 insertions(+), 424 deletions(-) diff --git a/neo/rawio/blackrockrawio.py b/neo/rawio/blackrockrawio.py index 17be9dfaf..86d35edf3 100644 --- a/neo/rawio/blackrockrawio.py +++ b/neo/rawio/blackrockrawio.py @@ -67,7 +67,7 @@ import quantities as pq -from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, +from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype) @@ -111,16 +111,16 @@ class BlackrockRawIO(BaseRawIO): Contrary to previsous version of the IO: * nsx_to_load is not a list * must be set at the init before parse_header() - + Examples: >>> reader = BlackrockRawIO(filename='FileSpec2.3001', nsx_to_load=5) >>> reader.parse_header() - - Inspect a set of file consisting of files FileSpec2.3001.ns5 and + + Inspect a set of file consisting of files FileSpec2.3001.ns5 and FileSpec2.3001.nev >>> print(reader) - + Display all informations about signal channels, units, segment size.... """ @@ -136,12 +136,12 @@ def __init__(self, filename=None, nsx_override=None, nev_override=None, BaseRawIO.__init__(self) self.filename = filename - + # remove extension from base _filenames for ext in self.extensions: if self.filename.endswith(os.path.extsep + ext): self.filename = self.filename.replace(os.path.extsep + ext, '') - + self.nsx_to_load = nsx_to_load # remove extensions from overrides @@ -172,8 +172,8 @@ def __init__(self, filename=None, nsx_override=None, nev_override=None, self._avail_files[ext] = True if ext.startswith('ns'): self._avail_nsx.append(int(ext[-1])) - - + + # These dictionaries are used internally to map the file specification # revision of the nsx and nev files to one of the reading routines #NSX @@ -214,34 +214,34 @@ def __init__(self, filename=None, nsx_override=None, nev_override=None, '2.1': self.__get_nonneural_evtypes_variant_a, '2.2': self.__get_nonneural_evtypes_variant_a, '2.3': self.__get_nonneural_evtypes_variant_b} - + def _parse_header(self): - + main_sampling_rate = 30000. 
- + event_channels = [] unit_channels = [] sig_channels = [] # Step1 NEV file if self._avail_files['nev']: - # Load file spec and headers of available - + # Load file spec and headers of available + # read nev file specification self.__nev_spec = self.__extract_nev_file_spec() # read nev headers self.__nev_basic_header, self.__nev_ext_header = \ self.__nev_header_reader[self.__nev_spec]() - + self.nev_data = self.__nev_data_reader[self.__nev_spec]() spikes = self.nev_data['Spikes'] - + #scan all channel to get number of Unit unit_channels = [] self.internal_unit_ids = [] #pair of chan['packet_id'], spikes['unit_class_nb'] for i in range(len(self.__nev_ext_header[b'NEUEVWAV'])): - + channel_id = self.__nev_ext_header[b'NEUEVWAV']['electrode_id'][i] chan_mask = (spikes['packet_id'] == channel_id) @@ -255,7 +255,7 @@ def _parse_header(self): wf_offset = 0. wf_units = 'uV' # TODO: Double check if this is the correct assumption (10 samples) - # default value: threshold crossing after 10 samples of waveform + # default value: threshold crossing after 10 samples of waveform wf_left_sweep = 10 wf_sampling_rate = main_sampling_rate unit_channels.append((name, _id, wf_units, wf_gain,wf_offset, wf_left_sweep, wf_sampling_rate)) @@ -265,14 +265,14 @@ def _parse_header(self): ev_dict = self.__nonneural_evtypes[self.__nev_spec](events_data) for ev_name in ev_dict: event_channels.append((ev_name, '', 'event')) - + # Step2 NSX file # Load file spec and headers of available nsx files self.__nsx_spec = {} self.__nsx_basic_header = {} self.__nsx_ext_header = {} self.__nsx_data_header = {} - + for nsx_nb in self._avail_nsx: spec = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb) # read nsx headers @@ -281,15 +281,15 @@ def _parse_header(self): # Read nsx data header(s) for nsxdef get_analogsignal_shape(self, block_index, seg_index): self.__nsx_data_header[nsx_nb] = self.__nsx_dataheader_reader[spec](nsx_nb) - + #We can load only one for one class instance if self.nsx_to_load is None and len(self._avail_nsx)>0: self.nsx_to_load = max(self._avail_nsx) - + if self.nsx_to_load is not None: spec = self.__nsx_spec[self.nsx_to_load] self.nsx_data = self.__nsx_data_reader[spec](self.nsx_to_load) - + self._nb_segment = len(self.nsx_data) sig_sampling_rate = float(main_sampling_rate / self.__nsx_basic_header[self.nsx_to_load]['period']) @@ -322,9 +322,9 @@ def _parse_header(self): (float(chan['max_digital_val']) - float(chan['min_digital_val'])) offset = -float(chan['min_digital_val'])*gain + float(chan['min_analog_val']) group_id = 0 - sig_channels.append((ch_name, ch_id, sig_sampling_rate, sig_dtype, + sig_channels.append((ch_name, ch_id, sig_sampling_rate, sig_dtype, units, gain, offset,group_id,)) - + #t_start/t_stop for segment are given by nsx limits or nev limits self._sigs_t_starts = [] self._seg_t_starts, self._seg_t_stops = [], [] @@ -353,7 +353,7 @@ def _parse_header(self): self._seg_t_starts.append(t_start) self._seg_t_stops.append(float(t_stop)) self._sigs_t_starts.append(float(t_start)) - + else: #not signal at all so 1 segment self._nb_segment = 1 @@ -376,18 +376,18 @@ def _parse_header(self): unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype) event_channels = np.array(event_channels, dtype=_event_channel_dtype) sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype) - + self.header = {} self.header['nb_block'] = 1 self.header['nb_segment'] = [self._nb_segment] self.header['signal_channels'] = sig_channels self.header['unit_channels'] = unit_channels 
self.header['event_channels'] = event_channels - - + + rec_datetime = self.__nev_params('rec_datetime') if self._avail_files['nev'] else None - - + + #Put annotations at some places for compatibility #with previous BlackrockIO version self._generate_minimal_annotations() @@ -407,7 +407,7 @@ def _parse_header(self): unit_ann['channel_id'] = self.internal_unit_ids[c][0] unit_ann['unit_id'] = self.internal_unit_ids[c][1] unit_ann['unit_tag'] = {0: 'unclassified', 255: 'noise'}.get(unit_id, str(unit_id)) - + for seg_index in range(self._nb_segment): seg_ann = block_ann['segments'][seg_index] seg_ann['file_origin'] = self.filename @@ -416,14 +416,14 @@ def _parse_header(self): # if more than 1 segment means pause # so datetime is valide only for seg_index=0 seg_ann['rec_datetime'] = rec_datetime - + for c in range(sig_channels.size): anasig_an = seg_ann['signals'][c] desc = "AnalogSignal {} from channel_id: {}, label: {}, nsx: {}".format( c, sig_channels['id'][c], sig_channels['name'][c], self.nsx_to_load) anasig_an['description'] = desc anasig_an['file_origin'] = self.filename+'.ns'+str(self.nsx_to_load) - + for c in range(unit_channels.size): channel_id, unit_id = self.internal_unit_ids[c] st_ann = seg_ann['units'][c] @@ -432,16 +432,16 @@ def _parse_header(self): st_ann['description'] = 'SpikeTrain channel_id: {}, unit_id: {}'.format( channel_id, unit_id) st_ann['file_origin'] = self.filename+'.nev' - - + + if self._avail_files['nev']: ev_dict = self.__nonneural_evtypes[self.__nev_spec](events_data) for c in range(event_channels.size): ev_ann = seg_ann['events'][c] name = event_channels['name'][c] ev_ann['description'] = ev_dict[name]['desc'] - - + + def _source_name(self): return self.filename @@ -462,16 +462,16 @@ def _get_signal_t_start(self, block_index, seg_index, channel_indexes): def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): assert block_index==0 memmap_data = self.nsx_data[seg_index] - + if channel_indexes is None: channel_indexes = slice(None) sig_chunk = memmap_data[i_start:i_stop, channel_indexes] return sig_chunk - - + + def _spike_count(self, block_index, seg_index, unit_index): channel_id, unit_id = self.internal_unit_ids[unit_index] - + all_spikes = self.nev_data['Spikes'] mask = (all_spikes['packet_id']==channel_id) & (all_spikes['unit_class_nb']==unit_id) if self._nb_segment==1: @@ -484,22 +484,22 @@ def _spike_count(self, block_index, seg_index, unit_index): timestamp = timestamp[sl] nb = timestamp.size return nb - + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): channel_id, unit_id = self.internal_unit_ids[unit_index] - + all_spikes = self.nev_data['Spikes'] - + #select by channel_id and unit_id mask = (all_spikes['packet_id']==channel_id) & (all_spikes['unit_class_nb']==unit_id) unit_spikes = all_spikes[mask] - + timestamp = unit_spikes['timestamp'] sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) timestamp = timestamp[sl] - + return timestamp - + def _get_timestamp_slice(self, timestamp, seg_index, t_start, t_stop): if self._nb_segment>1: #we must clip event in seg time limits @@ -507,44 +507,45 @@ def _get_timestamp_slice(self, timestamp, seg_index, t_start, t_stop): t_start = self._seg_t_starts[seg_index] if t_stop is None: t_stop = self._seg_t_stops[seg_index] - + if t_start is None: ind_start = None else: ts = np.math.ceil(t_start*self.__nev_basic_header['timestamp_resolution']) ind_start = np.searchsorted(timestamp, ts) - + if t_stop is None: ind_stop = None 
else: ts = int(t_stop*self.__nev_basic_header['timestamp_resolution']) ind_stop = np.searchsorted(timestamp, ts) #+1 - + sl = slice(ind_start, ind_stop) return sl - + def _rescale_spike_timestamp(self, spike_timestamps, dtype): spike_times = spike_timestamps.astype(dtype) spike_times /= self.__nev_basic_header['timestamp_resolution'] return spike_times - + def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop): channel_id, unit_id = self.internal_unit_ids[unit_index] all_spikes = self.nev_data['Spikes'] mask = (all_spikes['packet_id']==channel_id) & (all_spikes['unit_class_nb']==unit_id) unit_spikes = all_spikes[mask] - + wf_dtype = self.__nev_params('waveform_dtypes')[channel_id] wf_size = self.__nev_params('waveform_size')[channel_id] waveforms = unit_spikes['waveform'].flatten().view(wf_dtype) waveforms = waveforms.reshape(int(unit_spikes.size), 1, int(wf_size)) - + timestamp = unit_spikes['timestamp'] sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) waveforms = waveforms[sl] - + waveforms = np.moveaxis(waveforms, 2, 0) + return waveforms def _event_count(self, block_index, seg_index, event_channel_index): @@ -561,36 +562,36 @@ def _event_count(self, block_index, seg_index, event_channel_index): timestamp = timestamp[sl] nb = timestamp.size return nb - + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): name = self.header['event_channels']['name'][event_channel_index] events_data = self.nev_data['NonNeural'] ev_dict = self.__nonneural_evtypes[self.__nev_spec](events_data)[name] - + timestamp = events_data[ev_dict['mask']]['timestamp'] labels = events_data[ev_dict['mask']][ev_dict['field']] - + #time clip sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) timestamp = timestamp[sl] labels = labels[sl] durations = None - + return timestamp, durations, labels - - + + def _rescale_event_timestamp(self, event_timestamps, dtype): ev_times = event_timestamps.astype(dtype) ev_times /= self.__nev_basic_header['timestamp_resolution'] return ev_times - + ################################################### ################################################### - + #Above here code from Lyuba Zehl, Michael Denker # coming from previous BlackrockIO - + def __extract_nsx_file_spec(self, nsx_nb): """ Extract file specification from an .nsx file. @@ -785,7 +786,7 @@ def __read_nsx_dataheader_variant_b( int(self.__nsx_basic_header[nsx_nb]['channel_count']) * 2 # define new offset (to possible next data block) offset = data_header[index]['offset_to_data_block'] + data_size - + index += 1 return data_header diff --git a/neo/rawio/examplerawio.py b/neo/rawio/examplerawio.py index f1f321420..e7847d432 100644 --- a/neo/rawio/examplerawio.py +++ b/neo/rawio/examplerawio.py @@ -13,18 +13,18 @@ * code hard! The main difficulty **is _parse_header()**. In short you have a create a mandatory dict than contains channel informations:: - + self.header = {} self.header['nb_block'] = 2 self.header['nb_segment'] = [2, 3] self.header['signal_channels'] = sig_channels self.header['unit_channels'] = unit_channels - self.header['event_channels'] = event_channels - + self.header['event_channels'] = event_channels + 2. Step 2: RawIO test: * create a file in neo/rawio/tests with the same name with "test_" prefix * copy paste neo/rawio/tests/test_examplerawio.py and do the same - + 3. 
Step 3 : Create the neo.io class with the wrapper * Create a file in neo/io/ that endith with "io.py" * Create a that hinerits bot yrou RawIO class and BaseFromRaw class @@ -39,7 +39,7 @@ """ from __future__ import unicode_literals, print_function, division, absolute_import -from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, +from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype) import numpy as np @@ -53,7 +53,7 @@ class ExampleRawIO(BaseRawIO): For a developer, it is just an example showing guidelines for someone who wants to develop a new IO module. - + Two rules for developers: * Respect the Neo RawIO API (:ref:`_neo_rawio_API`) * Follow :ref:`_io_guiline` @@ -65,7 +65,7 @@ class ExampleRawIO(BaseRawIO): * have 3 unit_channel * have 2 event channel: one have *type=event*, the other have *type=epoch* - + Usage: >>> import neo.rawio @@ -74,19 +74,19 @@ class ExampleRawIO(BaseRawIO): >>> print(r) >>> raw_chunk = r.get_analogsignal_chunk(block_index=0, seg_index=0, i_start=0, i_stop=1024, channel_names=channel_names) - >>> float_chunk = reader.rescale_signal_raw_to_float(raw_chunk, dtype='float64', + >>> float_chunk = reader.rescale_signal_raw_to_float(raw_chunk, dtype='float64', channel_indexes=[0, 3, 6]) >>> spike_timestamp = reader.spike_timestamps(unit_index=0, t_start=None, t_stop=None) >>> spike_times = reader.rescale_spike_timestamp(spike_timestamp, 'float64') >>> ev_timestamps, _, ev_labels = reader.event_timestamps(event_channel_index=0) - + """ extensions = ['fake'] rawmode = 'one-file' def __init__(self, filename=''): BaseRawIO.__init__(self) #note that this filename is ued in self._source_name - self.filename = filename + self.filename = filename def _source_name(self): # this function is used by __repr__ @@ -94,7 +94,7 @@ def _source_name(self): # But for URL you could mask some part of the URL to keep # the main part. return self.filename - + def _parse_header(self): # This is the central of a RawIO # we need to collect in the original format all @@ -102,8 +102,8 @@ def _parse_header(self): # at any place in the file # In short _parse_header can be slow but # _get_analogsignal_chunk need to be as fast as possible - - + + #create signals channels information #This is mandatory!!!! #gain/offset/units are really important because @@ -125,11 +125,11 @@ def _parse_header(self): #group_id isonly for special cases when channel have diferents # sampling rate for instance. See TdtIO for that. #Here this is the general case :all channel have the same characteritics - group_id = 0 + group_id = 0 sig_channels.append((ch_name, chan_id, sr, dtype, units, gain,offset, group_id)) sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype) - - + + #creating units channels #This is mandatory!!!! #Note that if there is no waveform at all in the file @@ -147,7 +147,7 @@ def _parse_header(self): wf_sampling_rate = 10000. unit_channels.append((unit_name, unit_id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)) unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype) - + # creating event/epoch channel #This is mandatory!!!! #In RawIO epoch and event they are dealt the same way. @@ -155,7 +155,7 @@ def _parse_header(self): event_channels.append(('Some events', 'ev_0', 'event')) event_channels.append(('Some epochs', 'ep_1', 'epoch')) event_channels = np.array(event_channels, dtype=_event_channel_dtype) - + #fille into header dict #This is mandatory!!!!! 
self.header = {} @@ -164,10 +164,10 @@ def _parse_header(self): self.header['signal_channels'] = sig_channels self.header['unit_channels'] = unit_channels self.header['event_channels'] = event_channels - + # insert some annotation at some place # at neo.io level IO are free to add some annoations - # to any object. To keep this functionality with the wrapper + # to any object. To keep this functionality with the wrapper # BaseFromRaw you can add annoations in a nested dict. self._generate_minimal_annotations() #If you are a lazy dev you can stop here. @@ -193,8 +193,8 @@ def _parse_header(self): event_an['nickname'] = 'Miss Event 0' elif c==1: event_an['nickname'] = 'MrEpoch 1' - - + + def _segment_t_start(self, block_index, seg_index): # this must return an float scale in second # this t_start will be shared by all object in the segment @@ -211,79 +211,79 @@ def _get_signal_size(self, block_index, seg_index, channel_indexes=None): # we are lucky: signals in all segment have the same shape!! (10.0 seconds) # it is not always the case # this must return an int = the number of sample - + # Note that channel_indexes can be ignored for most cases # except for several sampling rate. return 100000 - + def _get_signal_t_start(self, block_index, seg_index, channel_indexes): # This give the t_start of signals. - # Very often this equal to _segment_t_start but not + # Very often this equal to _segment_t_start but not # always. # this must return an float scale in second # Note that channel_indexes can be ignored for most cases # except for several sampling rate. - + #Here this is the same. # this is not always the case return self._segment_t_start(block_index, seg_index) - + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): - # this must return a signal chunk limited with + # this must return a signal chunk limited with # i_start/i_stop (can be None) # channel_indexes can be None (=all channel) or a list or numpy.array # This must return a numpy array 2D (even with one channel). # This must return the orignal dtype. No conversion here. # This must as fast as possible. # Everything that can be done in _parse_header() must not be here. - + #Here we are lucky: our signals is always zeros!! #it is not always the case - #internally signals are int16 + #internally signals are int16 #convertion to real units is done with self.header['signal_channels'] - + if i_start is None: i_start=0 if i_stop is None: i_stop=100000 - + assert i_start>=0, "I don't like your jokes" assert i_stop<=100000, "I don't like your jokes" - + if channel_indexes is None: nb_chan = 16 else: nb_chan = len(channel_indexes) raw_signals = np.zeros((i_stop-i_start, nb_chan), dtype='int16') return raw_signals - + def _spike_count(self, block_index, seg_index, unit_index): # Must return the nb of spike for given (block_index, seg_index, unit_index) #we are lucky: our units have all the same nb of spikes!! #it is not always the case nb_spikes = 20 return nb_spikes - + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): # In our IO, timstamp are internally coded 'int64' and they # represent the index of the signals 10kHz # we are lucky: spikes have the same discharge in all segments!! # incredible neuron!! 
This is not always the case - + #the same clip t_start/t_start must be used in _spike_raw_waveforms() - + ts_start = (self._segment_t_start(block_index, seg_index)*10000) - + spike_timestamps = np.arange(0, 10000, 500) + ts_start - + if t_start is not None or t_stop is not None: #restricte spikes to given limits (in seconds) lim0 = int(t_start*10000) lim1 = int(t_stop*10000) mask = (spike_timestamps>=lim0) & (spike_timestamps<=lim1) spike_timestamps = spike_timestamps[mask] - + return spike_timestamps - + def _rescale_spike_timestamp(self, spike_timestamps, dtype): #must rescale to second a particular spike_timestamps #with a fixed dtype so the user can choose the precisino he want. @@ -299,19 +299,19 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, # If there there is no waveform supported in the # IO them _spike_raw_waveforms must return None - + #In our IO waveforms come from all channels #they are int16 #convertion to real units is done with self.header['unit_channels'] #Here, we have a realistic case: all waveforms are only noise. #it is not always the case # we 20 spikes with a sweep of 50 (5ms) - + np.random.seed(2205) # a magic number (my birthday) waveforms = np.random.randint(low=-2**4, high=2**4, size=20*50, dtype='int16') - waveforms = waveforms.reshape(20, 1, 50) + waveforms = waveforms.reshape(50, 20, 1) return waveforms - + def _event_count(self, block_index, seg_index, event_channel_index): # event and spike are very similar # we have 2 event channels @@ -321,13 +321,13 @@ def _event_count(self, block_index, seg_index, event_channel_index): elif event_channel_index==1: #epoch channel return 10 - + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): #the main difference between spike channel and event channel # is that for here we have 3 numpy array timestamp, durations, labels #durations must be None for 'event' #label must a dtype ='U' - + # in our IO event are directly coded in seconds seg_t_start = self._segment_t_start(block_index, seg_index) if event_channel_index==0: @@ -338,31 +338,31 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_ timestamp = np.arange(0, 10, dtype='float64') + .5 + seg_t_start durations = np.ones((10), dtype='float64') * .25 labels = np.array(['zoneX']*5+['zoneZ']*5, dtype='U12') - + if t_start is not None: keep = timestamp>=t_start timestamp, labels = timestamp[keep], labels[keep] if durations is not None: durations = durations[keep] - + if t_stop is not None: keep = timestamp<=t_stop timestamp, labels = timestamp[keep], labels[keep] if durations is not None: durations = durations[keep] - + return timestamp, durations, labels - + def _rescale_event_timestamp(self, event_timestamps, dtype): #must rescale to second a particular event_timestamps #with a fixed dtype so the user can choose the precisino he want. 
- + # really easy here because in our case it is already seconds event_times = event_timestamps.astype(dtype) return event_times - + def _rescale_epoch_duration(self, raw_duration, dtype): # really easy here because in our case it is already seconds durations = raw_duration.astype(dtype) diff --git a/neo/rawio/neuralynxrawio.py b/neo/rawio/neuralynxrawio.py index 5b5df47e4..57736e872 100644 --- a/neo/rawio/neuralynxrawio.py +++ b/neo/rawio/neuralynxrawio.py @@ -21,7 +21,7 @@ #from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3 -from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, +from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype) import numpy as np @@ -38,15 +38,15 @@ class NeuralynxRawIO(BaseRawIO): """" Class for reading dataset recorded by Neuralynx. - + Examples: >>> reader = NeuralynxRawIO(dirname='Cheetah_v5.5.1/original_data') >>> reader.parse_header() - + Inspect all file in the directory. >>> print(reader) - + Display all informations about signal channels, units, segment size.... """ extensions = ['nse', 'ncs', 'nev', 'ntt'] @@ -54,26 +54,26 @@ class NeuralynxRawIO(BaseRawIO): def __init__(self, dirname='', **kargs): self.dirname = dirname BaseRawIO.__init__(self, **kargs) - + def _source_name(self): return self.dirname - + def _parse_header(self): - + sig_channels = [] unit_channels = [] event_channels = [] - + self.ncs_filenames = OrderedDict() #chan_id: filename self.nse_ntt_filenames = OrderedDict() #chan_id: filename self.nev_filenames = OrderedDict() #chan_id: filename - + self._nev_memmap = {} self._spike_memmap = {} self.internal_unit_ids = [] #channel_index > (channel_id, unit_id) self.internal_event_ids = [] - + # explore the directory looking for ncs, nev, nse and ntt # And construct channels headers signal_annotations = [] @@ -81,17 +81,17 @@ def _parse_header(self): event_annotations = [] for filename in sorted(os.listdir(self.dirname)): filename = os.path.join(self.dirname, filename) - + _, ext = os.path.splitext(filename) ext = ext[1:]#remove dot if ext not in self.extensions: continue - + #All file have more or less the same header structure info = read_txt_header(filename) chan_name = info['channel_name'] chan_id = info['channel_id'] - + if ext=='ncs': # a signal channels units = 'uV' @@ -100,32 +100,32 @@ def _parse_header(self): gain *= -1 offset = 0. 
group_id = 0 - sig_channels.append((chan_name, chan_id, info['sampling_rate'], 'int16', + sig_channels.append((chan_name, chan_id, info['sampling_rate'], 'int16', units, gain,offset, group_id)) self.ncs_filenames[chan_id] = filename keys = ['DspFilterDelay_µs', 'recording_opened', 'FileType', 'DspDelayCompensation', 'recording_closed', - 'DspLowCutFilterType', 'HardwareSubSystemName', 'DspLowCutNumTaps', 'DSPLowCutFilterEnabled', + 'DspLowCutFilterType', 'HardwareSubSystemName', 'DspLowCutNumTaps', 'DSPLowCutFilterEnabled', 'HardwareSubSystemType', 'DspHighCutNumTaps', 'ADMaxValue', 'DspLowCutFrequency', - 'DSPHighCutFilterEnabled', 'RecordSize', 'InputRange', 'DspHighCutFrequency', + 'DSPHighCutFilterEnabled', 'RecordSize', 'InputRange', 'DspHighCutFrequency', 'input_inverted', 'NumADChannels', 'DspHighCutFilterType', ] d = {k:info[k] for k in keys if k in info} signal_annotations.append(d) - + elif ext in ('nse', 'ntt'): # nse and ntt are pretty similar execept for the wavform shape # a file can contain several unit_id (so several unit channel) assert chan_id not in self.nse_ntt_filenames self.nse_ntt_filenames[chan_id] = filename, 'Several nse or ntt files have the same unit_id!!!' - + dtype = get_nse_or_ntt_dtype(info, ext) data = np.memmap(filename, dtype=dtype, mode='r', offset=HEADER_SIZE) self._spike_memmap[chan_id] = data - + unit_ids = np.unique(data['unit_id']) for unit_id in unit_ids: # a spike channel for each (chan_id, unit_id) self.internal_unit_ids.append((chan_id, unit_id)) - + unit_name = "ch{}#{}".format(chan_id, unit_id) unit_id = '{}'.format(unit_id) wf_units = 'uV' @@ -135,10 +135,10 @@ def _parse_header(self): wf_offset = 0. wf_left_sweep = -1 #DONT KNOWN wf_sampling_rate = info['sampling_rate'] - unit_channels.append((unit_name, '{}'.format(unit_id), wf_units, wf_gain, wf_offset, + unit_channels.append((unit_name, '{}'.format(unit_id), wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)) unit_annotations.append(dict(file_origin=filename)) - + elif ext=='nev': # an event channel # each ('event_id', 'ttl_input') give a new event channel @@ -151,21 +151,21 @@ def _parse_header(self): name = '{} event_id={} ttl={}'.format(chan_name, event_id, ttl_input) event_channels.append((name, chan_id, 'event')) self.internal_event_ids.append(internal_event_id) - + self._nev_memmap[chan_id] = data - + sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype) unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype) event_channels = np.array(event_channels, dtype=_event_channel_dtype) - + if sig_channels.size>0: sampling_rate = np.unique(sig_channels['sampling_rate']) assert sampling_rate.size==1 self._sigs_sampling_rate = sampling_rate[0] - + #read ncs files for gaps detection and nb_segment computation self.read_ncs_files(self.ncs_filenames) - + #timestamp limit in nev, nse #so need to scan all spike and event to ts0, ts1 = None, None @@ -178,7 +178,7 @@ def _parse_header(self): ts1 = ts[-1] ts0 = min(ts0, ts[0]) ts1 = max(ts0, ts[-1]) - + if self._timestamp_limits is None: #case NO ncs but HAVE nev or nse self._timestamp_limits = [(ts0, ts1)] @@ -200,7 +200,7 @@ def _parse_header(self): self._seg_t_stops = self._sigs_t_stop self.global_t_start = self._sigs_t_start[0] self.global_t_stop = self._sigs_t_stop[-1] - + #fille into header dict self.header = {} self.header['nb_block'] = 1 @@ -208,7 +208,7 @@ def _parse_header(self): self.header['signal_channels'] = sig_channels self.header['unit_channels'] = unit_channels self.header['event_channels'] = 
event_channels - + # Annotations self._generate_minimal_annotations() bl_annotations = self.raw_annotations['blocks'][0] @@ -219,83 +219,83 @@ def _parse_header(self): for c in range(sig_channels.size): sig_ann = seg_annotations['signals'][c] sig_ann.update(signal_annotations[c]) - + for c in range(unit_channels.size): unit_ann = seg_annotations['units'][c] unit_ann.update(unit_annotations[c]) - + for c in range(event_channels.size): #annotations for channel events event_id, ttl_input = self.internal_event_ids[c] chan_id = event_channels[c]['id'] - + ev_ann = seg_annotations['events'][c] ev_ann['file_origin'] = self.nev_filenames[chan_id] - #~ ev_ann['marker_id'] = - #~ ev_ann['nttl'] = - #~ ev_ann['digital_marker'] = - #~ ev_ann['analog_marker'] = - + #~ ev_ann['marker_id'] = + #~ ev_ann['nttl'] = + #~ ev_ann['digital_marker'] = + #~ ev_ann['analog_marker'] = + def _segment_t_start(self, block_index, seg_index): return self._seg_t_starts[seg_index] - self.global_t_start def _segment_t_stop(self, block_index, seg_index): return self._seg_t_stops[seg_index] - self.global_t_start - + def _get_signal_size(self, block_index, seg_index, channel_indexes): return self._sigs_length[seg_index] def _get_signal_t_start(self, block_index, seg_index, channel_indexes): return self._sigs_t_start[seg_index] - self.global_t_start - + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): if i_start is None: i_start=0 if i_stop is None: i_stop=self._sigs_length[seg_index] - + block_start = i_start//BLOCK_SIZE block_stop = i_stop//BLOCK_SIZE+1 sl0 = i_start % 512 sl1 = sl0 + (i_stop-i_start) - + if channel_indexes is None: channel_indexes = slice(None) channel_ids = self.header['signal_channels'][channel_indexes]['id'] - + sigs_chunk = np.zeros((i_stop-i_start, len(channel_ids)), dtype='int16') for i, chan_id in enumerate(channel_ids): data = self._sigs_memmap[seg_index][chan_id] sub = data[block_start:block_stop] sigs_chunk[:, i] = sub['samples'].flatten()[sl0:sl1] - + return sigs_chunk - - + + def _spike_count(self, block_index, seg_index, unit_index): chan_id, unit_id = self.internal_unit_ids[unit_index] data = self._spike_memmap[chan_id] ts = data['timestamp'] - + ts0, ts1 = self._timestamp_limits[seg_index] - + keep = (ts>=ts0) & (ts<=ts1) & (unit_id==data['unit_id']) nb_spike = int(data[keep].size) return nb_spike - + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): chan_id, unit_id = self.internal_unit_ids[unit_index] data = self._spike_memmap[chan_id] ts = data['timestamp'] - + ts0, ts1 = self._timestamp_limits[seg_index] if t_start is not None: ts0 = int((t_start+self.global_t_start)*1e6) if t_start is not None: ts1 = int((t_stop+self.global_t_start)*1e6) - + keep = (ts>=ts0) & (ts<=ts1) & (unit_id==data['unit_id']) timestamps = ts[keep] return timestamps - + def _rescale_spike_timestamp(self, spike_timestamps, dtype): spike_times = spike_timestamps.astype(dtype) spike_times /= 1e6 @@ -306,15 +306,15 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, chan_id, unit_id = self.internal_unit_ids[unit_index] data = self._spike_memmap[chan_id] ts = data['timestamp'] - + ts0, ts1 = self._timestamp_limits[seg_index] if t_start is not None: ts0 = int((t_start+self.global_t_start)*1e6) if t_start is not None: ts1 = int((t_stop+self.global_t_start)*1e6) - + keep = (ts>=ts0) & (ts<=ts1) & (unit_id==data['unit_id']) - + wfs = data[keep]['samples'] if wfs.ndim ==2: #case for nse @@ -322,9 +322,10 @@ def 
_get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, else: #case for ntt change (n, 32, 4) to (n, 4, 32) waveforms = wfs.swapaxes(1,2) - + waveforms = np.moveaxis(waveforms, 2, 0) + return waveforms - + def _event_count(self, block_index, seg_index, event_channel_index): event_id, ttl_input = self.internal_event_ids[event_channel_index] chan_id = self.header['event_channels'][event_channel_index]['id'] @@ -335,18 +336,18 @@ def _event_count(self, block_index, seg_index, event_channel_index): (data['ttl_input']==ttl_input) nb_event = int(data[keep].size) return nb_event - + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): event_id, ttl_input = self.internal_event_ids[event_channel_index] chan_id = self.header['event_channels'][event_channel_index]['id'] data = self._nev_memmap[chan_id] ts0, ts1 = self._timestamp_limits[seg_index] - + if t_start is not None: ts0 = int((t_start+self.global_t_start)*1e6) if t_start is not None: ts1 = int((t_stop+self.global_t_start)*1e6) - + ts = data['timestamp'] keep = (ts>=ts0) & (ts<=ts1) & (data['event_id']==event_id) &\ (data['ttl_input']==ttl_input) @@ -355,13 +356,13 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_ labels = subdata['event_string'].astype('U') durations = None return timestamps, durations, labels - + def _rescale_event_timestamp(self, event_timestamps, dtype): event_times = event_timestamps.astype(dtype) event_times /= 1e6 event_times -= self.global_t_start return event_times - + def read_ncs_files(self, ncs_filenames): """ Given a list of ncs files contrsuct: @@ -371,40 +372,40 @@ def read_ncs_files(self, ncs_filenames): * self._sigs_length = [] * self._nb_segment * self._timestamp_limits - + The first file is read entirely to detect gaps in timestamp. each gap lead to a new segment. - + Other files are not read entirely but we check than gaps are at the same place. - - + + gap_indexes can be given (when cached) to avoid full read. - + """ if len(ncs_filenames)==0: self._nb_segment = 1 self._timestamp_limits = None return - + good_delta = int(BLOCK_SIZE*1e6/self._sigs_sampling_rate) chan_id0 = list(ncs_filenames.keys())[0] filename0 = ncs_filenames[chan_id0] data0 = np.memmap(filename0, dtype=ncs_dtype, mode='r', offset=HEADER_SIZE) - + gap_indexes = None if self.use_cache: gap_indexes = self._cache.get('gap_indexes') - + #detect gaps on first file if gap_indexes is None: #this can be long!!!! 
timestamps0 = data0['timestamp'] deltas0 = np.diff(timestamps0) - + #It should be that: #gap_indexes, = np.nonzero(deltas0!=good_delta) - + # but for a file I have found many deltas0==15999 deltas0==16000 # I guess this is a round problem # So this is the same with a tolerance of 1 or 2 ticks @@ -413,13 +414,13 @@ def read_ncs_files(self, ncs_filenames): mask &= (deltas0!=good_delta-tolerance) mask &= (deltas0!=good_delta+tolerance) gap_indexes, = np.nonzero(mask) - + if self.use_cache: self.add_in_cache(gap_indexes=gap_indexes) - + gap_bounds = [0] + (gap_indexes+1).tolist() + [data0.size] self._nb_segment = len(gap_bounds)-1 - + self._sigs_memmap = [ {} for seg_index in range(self._nb_segment) ] self._sigs_t_start = [] self._sigs_t_stop = [] @@ -429,17 +430,17 @@ def read_ncs_files(self, ncs_filenames): for chan_id, ncs_filename in self.ncs_filenames.items(): data = np.memmap(ncs_filename, dtype=ncs_dtype, mode='r', offset=HEADER_SIZE) assert data.size==data0.size, 'ncs files do not have the same data length' - + for seg_index in range(self._nb_segment): i0 = gap_bounds[seg_index] i1 = gap_bounds[seg_index+1] - + assert data[i0]['timestamp']==data0[i0]['timestamp'], 'ncs files do not have the same gaps' assert data[i1-1]['timestamp']==data0[i1-1]['timestamp'], 'ncs files do not have the same gaps' - + subdata = data[i0:i1] self._sigs_memmap[seg_index][chan_id] = subdata - + if chan_id==chan_id0: ts0 = subdata[0]['timestamp'] ts1 = subdata[-1]['timestamp'] + np.uint64(BLOCK_SIZE/self._sigs_sampling_rate*1e6) @@ -450,9 +451,9 @@ def read_ncs_files(self, ncs_filenames): self._sigs_t_stop.append(t_stop) length = subdata.size * BLOCK_SIZE self._sigs_length.append(length) - -# keys in + +# keys in txt_header_keys = [ ('AcqEntName', 'channel_name', None),#used ('FileType', '', None), @@ -515,7 +516,7 @@ def read_txt_header(filename): with open(filename, 'rb') as f: txt_header = f.read(HEADER_SIZE) txt_header = txt_header.strip(b'\x00').decode('latin-1') - + # find keys info = OrderedDict() for k1, k2, type_ in txt_header_keys: @@ -527,28 +528,28 @@ def read_txt_header(filename): info[k2] = r[0] if type_ is not None: info[k2] = type_(info[k2]) - + #some conversions if 'bit_to_microVolt' in info: info['bit_to_microVolt'] = info['bit_to_microVolt']*1e6 if 'version' in info: version = info['version'].replace('"', '') info['version'] = distutils.version.LooseVersion(version) - + # filename and datetime if info['version']<=distutils.version.LooseVersion('5.6.4'): datetime1_regex = '## Time Opened \(m/d/y\): (?P\S+) \(h:m:s\.ms\) (?P
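The core change of PATCH 1/7, as a runnable sketch (assuming a neo checkout with this patch applied, plus numpy and quantities; variable names are illustrative): waveforms move from (spike, channel, time) to (time, spike, channel), mirroring the (time, channel) layout of AnalogSignal, and spike_duration is now derived from shape[0] instead of shape[2].

import numpy as np
import quantities as pq
from neo.core import SpikeTrain

# Old layout: (spike, channel, time), here 3 spikes x 1 channel x 2 samples.
old_wfs = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV

# New layout: (time, spike, channel), matching AnalogSignal's (time, channel).
new_wfs = np.moveaxis(old_wfs, 2, 0)
assert new_wfs.shape == (2, 3, 1)

train = SpikeTrain([3, 4, 5] * pq.s, t_stop=10.0 * pq.s,
                   waveforms=new_wfs, sampling_rate=30. * pq.kHz)

# The spike count is now validated against shape[1], the spike axis ...
assert train.waveforms.shape[1] == len(train)
# ... and spike_duration is shape[0], the time axis, over the sampling rate.
assert train.spike_duration == new_wfs.shape[0] / train.sampling_rate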
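Slicing and sorting consequently index axis 1 (spikes) and leave the time and channel axes whole. A standalone sketch of the equivalences the patched __getitem__ and sort() rely on; note that numpy treats a list of slices, as used in __getitem__ here, as deprecated fancy indexing, while the equivalent tuple is the future-proof spelling:

import numpy as np

# wfs stands in for SpikeTrain.waveforms after the patch: (time, spike, channel).
wfs = np.arange(2 * 3 * 1).reshape(2, 3, 1)

i = slice(1, 3)
# __getitem__ builds [slice(None), i, slice(None)]; a tuple does the same job
# without the deprecation warning in recent numpy.
assert np.array_equal(wfs[(slice(None), i, slice(None))], wfs[:, i, :])

# sort() reorders only the spike axis, by spike time.
sort_indices = np.argsort([3., 1., 2.])
assert wfs[:, sort_indices, :].shape == wfs.shape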
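merge() follows the same rule: with the spike axis moved to the middle, the old np.vstack no longer stacks spikes, so the patch concatenates explicitly along axis=1 and then reorders that axis by the merged spike times. A minimal standalone sketch of that logic:

import numpy as np

wfs_a = np.zeros((32, 3, 1))                  # 3 spikes, new layout
wfs_b = np.ones((32, 2, 1))                   # 2 more spikes
times = np.array([0.1, 0.5, 0.9, 0.2, 0.6])  # merged spike times, a then b

sorting = np.argsort(times)
wfs_stack = np.concatenate((wfs_a, wfs_b), axis=1)[:, sorting, :]
assert wfs_stack.shape == (32, 5, 1)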
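PATCH 2/7 pushes the convention into the RawIOs: buffers are still read in their on-disk order, and a single np.moveaxis at the end of _get_spike_raw_waveforms produces (time, spike, channel). A sketch of the Blackrock path, with assumed (illustrative) spike count and waveform size:

import numpy as np

n_spikes, wf_size = 20, 48                         # illustrative sizes
raw = np.zeros(n_spikes * wf_size, dtype='int16')  # flat on-disk buffer

waveforms = raw.reshape(n_spikes, 1, wf_size)  # (spike, channel, time) as read
waveforms = np.moveaxis(waveforms, 2, 0)       # -> (time, spike, channel)
assert waveforms.shape == (wf_size, n_spikes, 1)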
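For Neuralynx .ntt tetrode records the samples arrive as (spike, sample, channel); the pre-existing swapaxes(1, 2) yields the old (spike, channel, sample) layout, and the added moveaxis finishes the switch:

import numpy as np

n = 7
ntt_samples = np.zeros((n, 32, 4), dtype='int16')  # (spike, sample, channel)

waveforms = ntt_samples.swapaxes(1, 2)             # (spike, channel, sample)
waveforms = np.moveaxis(waveforms, 2, 0)           # (time, spike, channel)
assert waveforms.shape == (32, n, 4)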
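One hedged reviewer-style caveat on ExampleRawIO: the patch replaces reshape(20, 1, 50) with reshape(50, 20, 1). For the random noise generated there the two are interchangeable, but reshape reorders the flat buffer rather than transposing axes, so on real data only the moveaxis form used by the other IOs preserves the waveforms:

import numpy as np

flat = np.arange(20 * 50, dtype='int16')

a = np.moveaxis(flat.reshape(20, 1, 50), 2, 0)  # true axis transpose
b = flat.reshape(50, 20, 1)                     # reinterprets the buffer

assert a.shape == b.shape == (50, 20, 1)
assert not np.array_equal(a, b)                 # same shape, different data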
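Finally, the fake-data generator now reconciles len(times) with waveforms.shape[1] rather than shape[0], padding with seeded random values or truncating. A compact standalone sketch of that reconciliation (shapes are illustrative):

import numpy as np
import quantities as pq

times = pq.Quantity([1.0, 2.0], units=pq.ms)
waveforms = np.zeros((10, 4, 1)) * pq.mV       # (time, spike, channel)

n_spikes = waveforms.shape[1]
if len(times) < n_spikes:
    np.random.seed(0)
    extra = np.random.random(n_spikes - len(times))
    times = pq.Quantity(np.concatenate([times.magnitude, extra]), units=pq.ms)
else:
    times = times[:n_spikes]

assert len(times) == waveforms.shape[1]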