Source code for bmtk.tests.utils.reports.spike_trains.test_csv_adaptor_mpi

import pytest
import numpy as np
import pandas as pd
import tempfile

from bmtk.utils.reports.spike_trains.spike_train_buffer import STMPIBuffer, STCSVMPIBufferV2
# from bmtk.utils.reports.spike_trains.adaptors.csv_adaptors import write_csv, write_csv_itr
from bmtk.utils.reports.spike_trains.spikes_file_writers import write_csv, write_csv_itr
from bmtk.utils.reports.spike_trains import sort_order

# Detect MPI once at import time. When mpi4py cannot be imported or initialised,
# fall back to single-process values so this module can still be imported and the
# MPI-only tests below are skipped through ``has_mpi``.
try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    MPI_rank = comm.Get_rank()
    MPI_size = comm.Get_size()
    has_mpi = True
except Exception:
    # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt/SystemExit
    # still propagate. ``comm`` is defined here as well so that code referencing
    # it can test for ``None`` instead of hitting a NameError.
    comm = None
    MPI_rank = 0
    MPI_size = 1
    has_mpi = False


def create_st_buffer_mpi(st_cls):
    """Instantiate a spike-trains buffer of class ``st_cls`` for the tests.

    Subclasses of ``STCSVMPIBufferV2`` are constructed with a ``cache_dir``
    keyword: rank 0 creates a temporary directory and, when MPI is available,
    its path is broadcast so every rank passes the same directory. Any other
    class is constructed with no arguments.

    :param st_cls: buffer class to instantiate.
    :return: a new instance of ``st_cls``.
    """
    if not issubclass(st_cls, STCSVMPIBufferV2):
        return st_cls()

    tmp_dir = tempfile.mkdtemp() if MPI_rank == 0 else None
    if has_mpi:
        # ``comm`` only exists when mpi4py was imported; without MPI there is a
        # single process (rank 0) which already holds the directory path.
        tmp_dir = comm.bcast(tmp_dir, 0)
    return st_cls(cache_dir=tmp_dir)
def tmpdir():
    """Return the path of a new temporary directory shared by all ranks.

    Rank 0 creates the directory with ``tempfile.mkdtemp``; when MPI is
    available the path is broadcast from rank 0 to the other ranks. The
    directory is not removed by this helper.

    :return: path of the temporary directory.
    """
    tmp_dir = tempfile.mkdtemp() if MPI_rank == 0 else None
    if has_mpi:
        # Skip the broadcast without MPI: ``comm`` is not usable and the only
        # process is rank 0, which already has the path.
        tmp_dir = comm.bcast(tmp_dir, 0)
    return tmp_dir
def tmpfile():
    """Return a temporary ``.csv`` file name shared by all ranks.

    Rank 0 obtains a unique name from ``tempfile.NamedTemporaryFile``; when MPI
    is available the name is broadcast from rank 0 to the other ranks. Only the
    name is kept: the placeholder file is closed (and thereby deleted) before
    returning, so callers receive a path that does not exist yet.

    :return: path to use for a temporary csv file.
    """
    tmp_file = None
    if MPI_rank == 0:
        # Close the handle explicitly instead of relying on garbage collection
        # to release (and delete) the placeholder file.
        with tempfile.NamedTemporaryFile(suffix='.csv') as handle:
            tmp_file = handle.name
    if has_mpi:
        # Skip the broadcast without MPI: ``comm`` is not usable and the only
        # process is rank 0, which already has the name.
        tmp_file = comm.bcast(tmp_file, 0)
    return tmp_file
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [STMPIBuffer, STCSVMPIBufferV2])
@pytest.mark.parametrize('write_fnc', [write_csv, write_csv_itr])
def test_write_csv(st_cls, write_fnc):
    """Write spikes added on every rank to one csv file and check its contents.

    Each rank adds five 'V1' spikes (node id and timestamp equal to its rank),
    one 'V2' spike (node id ``MPI_size``) and four spikes in a rank-specific
    population named 'R<rank>'. The file is then read back with pandas.
    """
    ranks = range(MPI_size)

    spikes = create_st_buffer_mpi(st_cls)
    spikes.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank]*5)
    spikes.add_spike(population='V2', node_id=MPI_size, timestamp=float(MPI_rank))
    spikes.add_spikes(
        population='R{}'.format(MPI_rank),
        node_ids=0,
        timestamps=[0.1, 0.2, 0.3, 0.4]
    )

    csv_path = tmpfile()
    write_fnc(csv_path, spikes)

    table = pd.read_csv(csv_path, sep=' ')

    # All populations from all ranks must be present in the file.
    expected_pops = {'V1', 'V2'}.union('R{}'.format(r) for r in ranks)
    assert set(table['population'].unique()) == expected_pops

    # Per rank: 5 'V1' rows + 1 'V2' row + 4 'R<rank>' rows, three columns.
    assert table.shape == (5*MPI_size + MPI_size + 4*MPI_size, 3)

    v1_rows = table[table['population'] == 'V1']
    assert set(v1_rows['node_ids'].unique()) == set(ranks)
    assert set(v1_rows['timestamps']) == {float(r) for r in ranks}

    for rank in ranks:
        rank_rows = table[table['population'] == 'R{}'.format(rank)]
        assert np.all(rank_rows['node_ids'] == [0, 0, 0, 0])
        assert np.allclose(np.sort(rank_rows['timestamps']), [0.1, 0.2, 0.3, 0.4])
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [STMPIBuffer, STCSVMPIBufferV2])
@pytest.mark.parametrize('write_fnc', [write_csv, write_csv_itr])
def test_write_csv_byid(st_cls, write_fnc):
    """Check that writing with ``sort_order.by_id`` orders rows by node id.

    Each rank adds two 'V1' spikes, for node ids ``MPI_size + MPI_rank`` and
    ``MPI_rank``, so the ids over all ranks are 0 .. 2*MPI_size - 1.
    """
    n_expected = MPI_size*2

    spikes = create_st_buffer_mpi(st_cls)
    spikes.add_spikes(
        population='V1',
        node_ids=[MPI_size + MPI_rank, MPI_rank],
        timestamps=[0.5, 1.0]
    )

    csv_path = tmpfile()
    write_fnc(csv_path, spikes, sort_order=sort_order.by_id)

    table = pd.read_csv(csv_path, sep=' ')
    assert table.shape == (n_expected, 3)

    populations = table['population'].unique()
    assert np.all(populations == ['V1'])

    # Rows must appear in increasing node-id order regardless of the writing rank.
    assert np.all(table['node_ids'] == np.arange(n_expected))
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [STMPIBuffer, STCSVMPIBufferV2])
@pytest.mark.parametrize('write_fnc', [write_csv, write_csv_itr])
def test_write_csv_bytime(st_cls, write_fnc):
    """Check that writing with ``sort_order.by_time`` orders rows by timestamp.

    Each rank adds two 'V1' spikes for node ``MPI_rank`` at times
    ``MPI_rank/10`` and ``(MPI_size + MPI_rank)/10``.
    """
    spike_times = np.array(
        [MPI_rank/10.0, (MPI_size + MPI_rank)/10.0],
        dtype=float
    )

    spikes = create_st_buffer_mpi(st_cls)
    spikes.add_spikes(population='V1', node_ids=[MPI_rank, MPI_rank], timestamps=spike_times)

    csv_path = tmpfile()
    write_fnc(csv_path, spikes, sort_order=sort_order.by_time)

    table = pd.read_csv(csv_path, sep=' ')
    assert table.shape == (MPI_size*2, 3)

    populations = table['population'].unique()
    assert np.all(populations == ['V1'])

    # Timestamps must be strictly increasing from one row to the next.
    time_steps = np.diff(table['timestamps'])
    assert np.all(time_steps > 0)
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st_cls', [STMPIBuffer, STCSVMPIBufferV2])
@pytest.mark.parametrize('write_fnc', [write_csv, write_csv_itr])
def test_write_csv_empty(st_cls, write_fnc):
    """Check that a buffer with no spikes is written as a csv with zero rows and three columns."""
    empty_spikes = create_st_buffer_mpi(st_cls)
    csv_path = tmpfile()

    write_fnc(csv_path, empty_spikes, sort_order=sort_order.by_time)

    table = pd.read_csv(csv_path, sep=' ')
    n_rows, n_cols = 0, 3
    assert table.shape == (n_rows, n_cols)
if __name__ == '__main__':
    # Manual entry point for running a single case without pytest.
    # The other cases are kept here, disabled, for quick switching:
    # test_write_csv(STMPIBuffer, write_csv_itr)
    # test_write_csv(STMPIBuffer, write_csv)
    # test_write_csv_byid(STMPIBuffer, write_csv)
    # test_write_csv_bytime(STMPIBuffer, write_csv)
    test_write_csv_empty(st_cls=STMPIBuffer, write_fnc=write_csv)