Source code for bmtk.tests.utils.reports.spike_trains.test_sonata_adaptor

import pytest
import numpy as np
import tempfile
import h5py
from six import string_types
import warnings

from bmtk.utils.reports.spike_trains.spike_train_buffer import STMemoryBuffer, STCSVBuffer
from bmtk.utils.reports.spike_trains import sort_order, pop_na
from bmtk.utils.reports.spike_trains.spike_train_readers import load_sonata_file, SonataSTReader, SonataOldReader, EmptySonataReader
from bmtk.utils.reports.spike_trains.spikes_file_writers import write_sonata, write_sonata_itr
from bmtk.utils.sonata.utils import check_magic, get_version, add_hdf5_magic, add_hdf5_version


def create_st_buffer(st_cls):
    # Helper for creating spike_trains object
    if issubclass(st_cls, STCSVBuffer):
        return st_cls(cache_dir=tempfile.mkdtemp())
    else:
        return st_cls()


@pytest.mark.parametrize('st_cls', [
    STMemoryBuffer,
    STCSVBuffer
])
@pytest.mark.parametrize('write_fnc', [
    write_sonata,
    write_sonata_itr
])
def test_write_sonata(st_cls, write_fnc):
    st = create_st_buffer(st_cls)
    st.add_spikes(population='V1', node_ids=0, timestamps=np.linspace(0, 1.0, 100))
    st.add_spikes(population='V1', node_ids=2, timestamps=np.linspace(2.0, 1.0, 10))
    st.add_spike(population='V1', node_id=1, timestamp=3.0)
    st.add_spikes(population='V2', node_ids=[3, 3, 3], timestamps=[0.25, 0.5, 0.75])

    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    write_fnc(tmpfile.name, st)

    with h5py.File(tmpfile.name, 'r') as h5:
        assert(check_magic(h5))
        assert(get_version(h5) is not None)
        assert('/spikes/V1' in h5)
        node_ids = h5['/spikes/V1/node_ids'][()]
        assert(len(node_ids) == 111)
        assert(set(np.unique(node_ids)) == {0, 1, 2})
        assert(len(h5['/spikes/V1/timestamps'][()]) == 111)
        assert('/spikes/V2' in h5)
        assert(np.all(h5['/spikes/V2/node_ids'][()] == [3, 3, 3]))
        # WARNING: Not all adaptors guarantee the order of spikes
        assert(np.allclose(np.sort(h5['/spikes/V2/timestamps'][()]), [0.25, 0.50, 0.75]))


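# Illustrative sketch, not part of the test suite: a minimal write/read round trip using only the
# API already exercised in this module (write_sonata plus the load_sonata_file factory). The helper
# name is hypothetical; the underscore prefix keeps pytest from collecting it.
def _round_trip_sketch():
    st = STMemoryBuffer()
    st.add_spikes(population='V1', node_ids=[0, 0, 1], timestamps=[0.1, 0.2, 0.3])

    # Write the buffered spikes to a temporary SONATA h5 file
    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    write_sonata(tmpfile.name, st)

    # Read it back through the factory and check the spike counts/times survived the round trip
    reader = load_sonata_file(tmpfile.name)
    assert(reader.n_spikes(population='V1') == 3)
    assert(np.allclose(np.sort(reader.get_times(0, population='V1')), [0.1, 0.2]))

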
@pytest.mark.parametrize('st_cls', [
    STMemoryBuffer,
    STCSVBuffer
])
@pytest.mark.parametrize('write_fnc', [
    write_sonata,
    write_sonata_itr
])
def test_write_sonata_compression(st_cls, write_fnc):
    def do_one_test(comp_type, test_type):
        st = create_st_buffer(st_cls)
        st.add_spikes(population='V1', node_ids=0, timestamps=np.linspace(0, 1.0, 100))
        st.add_spikes(population='V1', node_ids=2, timestamps=np.linspace(2.0, 1.0, 10))
        st.add_spike(population='V1', node_id=1, timestamp=3.0)
        st.add_spikes(population='V2', node_ids=[3, 3, 3], timestamps=[0.25, 0.5, 0.75])

        tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
        write_fnc(tmpfile.name, st, compression=comp_type)

        with h5py.File(tmpfile.name, 'r') as h5:
            assert(check_magic(h5))
            assert(get_version(h5) is not None)
            assert('/spikes/V1' in h5)
            assert(h5['/spikes/V1/node_ids'].compression == test_type)
            assert(h5['/spikes/V1/timestamps'].compression == test_type)
            assert('/spikes/V2' in h5)
            assert(h5['/spikes/V2/node_ids'].compression == test_type)
            # WARNING: Not all adaptors guarantee the order of spikes
            assert(h5['/spikes/V2/timestamps'].compression == test_type)

    # An integer (0-9) specifies the gzip compression level
    comp_types = [None, 0, 'gzip', 3, 'lzf']
    test_types = [None, 'gzip', 'gzip', 'gzip', 'lzf']
    for comp_type, test_type in zip(comp_types, test_types):
        do_one_test(comp_type, test_type)


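# Illustrative sketch of the compression mapping checked above, not part of the test suite. It
# assumes the writers pass the ``compression`` argument straight through to h5py, where an integer
# in the range 0-9 is shorthand for gzip at that level; that is why the integer cases are expected
# to read back as 'gzip'. The helper name is hypothetical.
def _h5py_compression_sketch():
    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmpfile.name, 'w') as h5:
        # compression=3 is equivalent to compression='gzip', compression_opts=3
        dset = h5.create_dataset('timestamps', data=np.arange(10.0), compression=3)
        assert(dset.compression == 'gzip')
        assert(dset.compression_opts == 3)

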
@pytest.mark.parametrize('st_cls', [
    STMemoryBuffer,
    STCSVBuffer
])
@pytest.mark.parametrize('write_fnc', [
    write_sonata,
    write_sonata_itr
])
def test_write_sonata_empty(st_cls, write_fnc):
    # Important use case: a valid simulation may run for a long time but not produce any spikes; make sure writing
    # an empty set of spike-trains to h5 doesn't fail.
    st = create_st_buffer(st_cls)

    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    write_fnc(tmpfile.name, st)

    with h5py.File(tmpfile.name, 'r') as h5:
        assert(check_magic(h5))
        assert(get_version(h5) is not None)
        assert('/spikes' in h5)
        assert(len(h5['/spikes']) == 0)


@pytest.mark.parametrize('st_cls', [
    STMemoryBuffer,
    STCSVBuffer
])
@pytest.mark.parametrize('write_fnc', [
    write_sonata,
    write_sonata_itr
])
def test_write_sonata_append(st_cls, write_fnc):
    # Check that we can append spikes data to an existing sonata file. Currently this only works if /spikes/<pop_name>
    # does not already exist, since appending to an h5 dataset can be prohibitive. In the future we may want to
    # implement it.
    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmpfile.name, 'w') as h5:
        h5.create_group('/spikes/V1')

    st = create_st_buffer(st_cls)
    st.add_spikes(population='V2', node_ids=0, timestamps=np.linspace(0, 1.0, 100))
    write_fnc(tmpfile.name, st, mode='a', sort_order=sort_order.by_id)

    with h5py.File(tmpfile.name, 'r') as h5:
        assert(check_magic(h5))
        assert(get_version(h5) is not None)
        assert('/spikes/V1' in h5)
        assert('/spikes/V2' in h5)
        assert(len(h5['/spikes/V2/node_ids']) == 100)
        assert(len(h5['/spikes/V2/timestamps']) == 100)

    # Throw an error if the same pop_name already exists
    tmpfile2 = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmpfile2.name, 'w') as h5:
        h5.create_group('/spikes/V2')

    with pytest.raises(ValueError):
        write_fnc(tmpfile2.name, st, mode='a', sort_order=sort_order.by_id)


@pytest.mark.parametrize('st_cls', [
    STMemoryBuffer,
    STCSVBuffer
])
@pytest.mark.parametrize('write_fnc', [
    write_sonata,
    write_sonata_itr
])
def test_write_sonata_bytime(st_cls, write_fnc):
    # Check we can sort by timestamps
    st = create_st_buffer(st_cls)
    st.add_spikes(population='V1', node_ids=0, timestamps=[0.5, 0.3, 0.1, 0.2, 0.4])

    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    write_fnc(tmpfile.name, st, sort_order=sort_order.by_time)

    with h5py.File(tmpfile.name, 'r') as h5:
        assert(check_magic(h5))
        assert(get_version(h5) is not None)
        assert(h5['/spikes/V1'].attrs['sorting'] == 'by_time')
        assert(np.all(h5['/spikes/V1/node_ids'][()] == [0, 0, 0, 0, 0]))
        assert(np.all(h5['/spikes/V1/timestamps'][()] == [0.1, 0.2, 0.3, 0.4, 0.5]))


@pytest.mark.parametrize('st_cls', [
    STMemoryBuffer,
    STCSVBuffer
])
@pytest.mark.parametrize('write_fnc', [
    write_sonata,
    write_sonata_itr
])
def test_write_sonata_byid(st_cls, write_fnc):
    # Check we can sort by node_ids
    st = create_st_buffer(st_cls)
    st.add_spikes(population='V1', node_ids=[2, 4, 2, 1, 3, 3, 6, 0], timestamps=[0.1]*8)

    tmpfile = tempfile.NamedTemporaryFile(suffix='.h5')
    write_fnc(tmpfile.name, st, sort_order=sort_order.by_id)

    with h5py.File(tmpfile.name, 'r') as h5:
        assert(check_magic(h5))
        assert(get_version(h5) is not None)
        assert(h5['/spikes/V1'].attrs['sorting'] == 'by_id')
        assert(np.all(h5['/spikes/V1/node_ids'][()] == [0, 1, 2, 2, 3, 3, 4, 6]))
        assert(np.all(h5['/spikes/V1/timestamps'][()] == [0.1]*8))


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_sonata_reader():
    # Test ability to read an existing sonata file
    tmp_h5 = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmp_h5.name, 'w') as h5:
        add_hdf5_magic(h5)
        add_hdf5_version(h5)
        h5.create_dataset('/spikes/V1/node_ids', data=[0, 0, 0, 0, 2, 1, 2], dtype=np.uint)
        h5.create_dataset('/spikes/V1/timestamps', data=[0.25, 0.5, 0.75, 1.0, 3.0, 0.001, 2.0], dtype=float)
        h5.create_dataset('/spikes/V2/node_ids', data=[10, 10, 10], dtype=np.uint)
        h5.create_dataset('/spikes/V2/timestamps', data=[4.0, 4.0, 4.0], dtype=float)
        h5.create_group('/spikes/V3')

    st = SonataSTReader(path=tmp_h5.name, default_population='V1')
    assert(set(st.populations) == {'V1', 'V2'})
    assert(st.n_spikes() == 7)
    assert(st.n_spikes(population='V1') == 7)
    assert(st.n_spikes(population='V2') == 3)
    assert(set(st.node_ids()) == {0, 1, 2})
    assert(set(st.node_ids(population='V1')) == {0, 1, 2})
    assert(np.all(st.node_ids(population='V2') == [10]))
    assert(np.allclose(np.sort(st.get_times(0)), [0.25, 0.50, 0.75, 1.0]))
    assert(np.allclose(st.get_times(1, population='V1'), [0.001]))
    assert(np.allclose(st.get_times(10, population='V2'), [4.0, 4.0, 4.0]))

    df = st.to_dataframe()
    assert(len(df) == 10)
    assert(set(df.columns) == {'timestamps', 'population', 'node_ids'})

    df = st.to_dataframe(populations='V1', sort_order=sort_order.by_id, with_population_col=False)
    assert(len(df) == 7)
    assert(set(df.columns) == {'timestamps', 'node_ids'})
    assert(np.all(np.diff(df['node_ids']) >= 0))

    all_spikes = list(st.spikes())
    assert(len(all_spikes) == 10)
    assert(isinstance(all_spikes[0][0], float))
    assert(isinstance(all_spikes[0][1], string_types))
    assert(isinstance(all_spikes[0][2], (int, np.uint)))


def test_oldsonata_reader():
    # A special reader for an older version of the spikes format
    tmp_h5 = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmp_h5.name, 'w') as h5:
        add_hdf5_magic(h5)
        add_hdf5_version(h5)
        h5.create_dataset('/spikes/gids', data=[0, 0, 0, 0, 2, 1, 2], dtype=np.uint)
        h5.create_dataset('/spikes/timestamps', data=[0.25, 0.5, 0.75, 1.0, 3.0, 0.001, 2.0], dtype=float)

    st = SonataOldReader(path=tmp_h5.name)
    assert(np.all(st.populations == [pop_na]))
    assert(st.n_spikes() == 7)
    assert(set(st.node_ids()) == {0, 1, 2})
    assert(np.allclose(np.sort(st.get_times(0)), [0.25, 0.50, 0.75, 1.0]))

    df = st.to_dataframe()
    assert(df.shape == (7, 3))
    assert(set(df.columns) == {'timestamps', 'population', 'node_ids'})

    all_spikes = list(st.spikes())
    assert(len(all_spikes) == 7)
    assert(isinstance(all_spikes[0][0], float))
    assert(all_spikes[0][1] == pop_na)
    assert(isinstance(all_spikes[0][2], (int, np.uint)))


def test_load_sonata():
    # Sonata adaptor's factory method
    warnings.simplefilter("ignore", UserWarning)

    tmp_sonata = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmp_sonata.name, 'w') as h5:
        add_hdf5_magic(h5)
        add_hdf5_version(h5)
        h5.create_dataset('/spikes/V1/node_ids', data=[0, 0, 0, 0, 2, 1, 2], dtype=np.uint)
        h5.create_dataset('/spikes/V1/timestamps', data=[0.25, 0.5, 0.75, 1.0, 3.0, 0.001, 2.0], dtype=float)
        h5.create_dataset('/spikes/V2/node_ids', data=[10, 10, 10], dtype=np.uint)
        h5.create_dataset('/spikes/V2/timestamps', data=[4.0, 4.0, 4.0], dtype=float)
        h5.create_group('/spikes/V3')

    tmp_sonata_old = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmp_sonata_old.name, 'w') as h5:
        add_hdf5_magic(h5)
        add_hdf5_version(h5)
        h5.create_dataset('/spikes/gids', data=[0, 0, 0, 0, 2, 1, 2], dtype=np.uint)
        h5.create_dataset('/spikes/timestamps', data=[0.25, 0.5, 0.75, 1.0, 3.0, 0.001, 2.0], dtype=float)

    tmp_sonata_empty = tempfile.NamedTemporaryFile(suffix='.h5')
    with h5py.File(tmp_sonata_empty.name, 'w') as h5:
        add_hdf5_magic(h5)
        add_hdf5_version(h5)
        h5.create_group('/spikes/')

    sr = load_sonata_file(tmp_sonata.name)
    assert(isinstance(sr, SonataSTReader))

    sr = load_sonata_file(tmp_sonata_old.name)
    assert(isinstance(sr, SonataOldReader))

    sr = load_sonata_file(tmp_sonata_empty.name)
    assert(isinstance(sr, EmptySonataReader))


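# Note: the parametrized tests above are intended to be run through pytest rather than the
# __main__ block below, e.g. `pytest -v test_sonata_adaptor.py` (invocation shown for illustration).

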
if __name__ == '__main__':
    # test_write_sonata(STMemoryBuffer(), write_sonata)
    # test_write_sonata(STMemoryBuffer(), write_sonata_itr)
    # test_write_sonata(STCSVBuffer(cache_dir=tempfile.mkdtemp()), write_sonata_itr)
    # test_write_sonata_empty()
    # test_write_sonata_bytime()
    # test_write_sonata_byid()
    # test_write_sonata_append()
    # test_old_populations('spike_files/spikes.old.h5')
    # test_single_populations('spike_files/spikes.one_pop.h5')
    # test_multi_populations('spike_files/spikes.multipop.h5')
    # test_multipop_with_default('spike_files/spikes.multipop.h5')
    # test_empty_spikes()
    # test_sonata_reader()
    # test_oldsonata_reader()
    test_load_sonata()