import pytest
import numpy as np
import h5py
import tempfile
from bmtk.utils.reports.spike_trains.spike_train_buffer import STMPIBuffer, STCSVMPIBufferV2
# from bmtk.utils.reports.spike_trains.adaptors.sonata_adaptors import write_sonata, write_sonata_itr
from bmtk.utils.reports.spike_trains.spikes_file_writers import write_sonata, write_sonata_itr
from bmtk.utils.sonata.utils import check_magic, get_version
from bmtk.utils.reports.spike_trains import sort_order
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
bcast = comm.bcast
MPI_rank = comm.Get_rank()
MPI_size = comm.Get_size()
has_mpi = True
except:
MPI_rank = 0
MPI_size = 1
has_mpi = False
[docs]
def create_st_buffer_mpi(st_cls):
# Helper for creating spike_trains object
if issubclass(st_cls, STCSVMPIBufferV2):
tmp_dir = tempfile.mkdtemp() if MPI_rank == 0 else None
tmp_dir = comm.bcast(tmp_dir, 0)
return st_cls(cache_dir=tmp_dir)
else:
return st_cls()
[docs]
def tmpdir():
tmp_dir = tempfile.mkdtemp() if MPI_rank == 0 else None
tmp_dir = comm.bcast(tmp_dir, 0)
return tmp_dir
[docs]
def tmpfile():
tmp_file = tempfile.NamedTemporaryFile(suffix='.h5').name if MPI_rank == 0 else None
tmp_file = comm.bcast(tmp_file, 0)
return tmp_file
[docs]
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [
STMPIBuffer,
STCSVMPIBufferV2
])
@pytest.mark.parametrize('write_fnc', [
write_sonata,
write_sonata_itr
])
def test_write_sonata(st_cls, write_fnc):
st = create_st_buffer_mpi(st_cls)
st.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank]*5)
st.add_spike(population='V2', node_id=MPI_size, timestamp=float(MPI_rank))
st.add_spikes(population='R{}'.format(MPI_rank), node_ids=0, timestamps=[0.1, 0.2, 0.3, 0.4])
tmp_h5 = tmpfile()
write_fnc(tmp_h5, st)
if MPI_rank == 0:
# Warnings: some systems creates lock even for reading an hdf5 file
with h5py.File(tmp_h5, 'r') as h5:
assert(check_magic(h5))
assert(get_version(h5) is not None)
assert(set(h5['/spikes'].keys()) >= {'R{}'.format(r) for r in range(MPI_size)} | {'V1', 'V2'})
assert(set(h5['/spikes/V1']['node_ids'][()]) == {i for i in range(MPI_size)})
assert(set(h5['/spikes/V2']['timestamps'][()]) == {float(i) for i in range(MPI_size)})
for r in range(MPI_size):
grp = h5['/spikes/R{}'.format(r)]
assert(np.all(grp['node_ids'][()] == [0, 0, 0, 0]))
assert(np.allclose(grp['timestamps'][()], [0.1, 0.2, 0.3, 0.4]))
[docs]
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [
STMPIBuffer,
STCSVMPIBufferV2
])
@pytest.mark.parametrize('write_fnc', [
write_sonata,
write_sonata_itr
])
def test_write_sonata_compression(st_cls, write_fnc):
def do_one_test(comp_type, test_type):
st = create_st_buffer_mpi(st_cls)
st.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank]*5)
st.add_spike(population='V2', node_id=MPI_size, timestamp=float(MPI_rank))
st.add_spikes(population='R{}'.format(MPI_rank), node_ids=0, timestamps=[0.1, 0.2, 0.3, 0.4])
tmp_h5 = tmpfile()
write_fnc(tmp_h5, st, compression=comp_type)
if MPI_rank == 0:
# Warnings: some systems creates lock even for reading an hdf5 file
with h5py.File(tmp_h5, 'r') as h5:
assert(check_magic(h5))
assert(get_version(h5) is not None)
assert(h5['/spikes/V1']['node_ids'].compression == test_type)
assert(h5['/spikes/V2']['timestamps'].compression == test_type)
for r in range(MPI_size):
grp = h5['/spikes/R{}'.format(r)]
assert(grp['node_ids'].compression == test_type)
assert(grp['timestamps'].compression == test_type)
comp_types = [None, 'gzip', 3, 'lzf']
test_types = [None, 'gzip', 'gzip', 'lzf']
for comp_type, test_type in zip(comp_types, test_types):
do_one_test(comp_type, test_type)
[docs]
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [
STMPIBuffer,
STCSVMPIBufferV2
])
@pytest.mark.parametrize('write_fnc', [
write_sonata,
write_sonata_itr
])
def test_write_sonata_byid(st_cls, write_fnc):
st = create_st_buffer_mpi(st_cls)
st.add_spikes(population='V1', node_ids=[MPI_size + MPI_rank, MPI_rank], timestamps=[0.5, 1.0])
tmp_h5 = tmpfile()
write_fnc(tmp_h5, st, sort_order=sort_order.by_id)
if MPI_rank == 0:
with h5py.File(tmp_h5, 'r') as h5:
assert(check_magic(h5))
assert(get_version(h5) is not None)
assert(np.all(h5['/spikes/V1']['node_ids'][()] == list(range(MPI_size*2))))
assert(len(h5['/spikes/V1']['timestamps'][()]) == MPI_size * 2)
[docs]
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [
STMPIBuffer,
STCSVMPIBufferV2
])
@pytest.mark.parametrize('write_fnc', [
write_sonata,
write_sonata_itr
])
def test_write_sonata_bytime(st_cls, write_fnc):
st = create_st_buffer_mpi(st_cls)
st.add_spikes(population='V1', node_ids=[MPI_rank, MPI_rank],
timestamps=np.array([MPI_rank/10.0, (MPI_size + MPI_rank)/10.0], dtype=float))
tmp_h5 = tmpfile()
write_fnc(tmp_h5, st, sort_order=sort_order.by_time)
if MPI_rank == 0:
with h5py.File(tmp_h5, 'r') as h5:
assert(check_magic(h5))
assert(get_version(h5) is not None)
assert(len(h5['/spikes/V1']['node_ids'][()]) == MPI_size*2)
assert(np.all(np.diff(h5['/spikes/V1']['timestamps'][()]) > 0))
[docs]
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [
STMPIBuffer,
STCSVMPIBufferV2
])
@pytest.mark.parametrize('write_fnc', [
write_sonata,
write_sonata_itr
])
def test_write_sonata_empty(st_cls, write_fnc):
st = create_st_buffer_mpi(st_cls)
tmp_h5 = tmpfile()
write_fnc(tmp_h5, st)
if MPI_rank == 0:
with h5py.File(tmp_h5, 'r') as h5:
assert(check_magic(h5))
assert(get_version(h5) is not None)
assert('/spikes' in h5)
if __name__ == '__main__':
# test_write_sonata(STMPIBuffer, write_sonata)
# test_write_sonata(STMPIBuffer, write_sonata_itr)
# test_write_sonata_byid(STMPIBuffer, write_sonata)
# test_write_sonata_bytime(STMPIBuffer, write_sonata)
test_write_sonata_empty(STMPIBuffer, write_sonata)