# Source code for bmtk.tests.utils.reports.spike_trains.test_spikes_buffer_mpi

import pytest
import numpy as np
import tempfile
from six import string_types


from bmtk.utils.reports.spike_trains.spike_train_buffer import STMPIBuffer, STCSVMPIBufferV2

# Optional MPI setup. When mpi4py can be imported, the world communicator supplies
# the rank, size and broadcast function used by this module. Otherwise single-process
# stand-ins are defined so the module still imports, and has_mpi is set to False.
try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    bcast = comm.bcast
    MPI_rank = comm.Get_rank()
    MPI_size = comm.Get_size()
    has_mpi = True
except Exception:
    # Exception (not a bare ``except``) so that KeyboardInterrupt and SystemExit
    # still propagate, while any import or initialisation failure falls back
    # to the serial stand-ins below.
    MPI_rank = 0
    MPI_size = 1

    def bcast(v, n):
        """Serial stand-in for ``comm.bcast``: return ``v`` unchanged; ``n`` is ignored."""
        return v

    has_mpi = False


def tmpdir():
    """Return a temporary directory path created on rank 0 and passed through ``bcast``.

    Only the process with ``MPI_rank == 0`` calls ``tempfile.mkdtemp()``; every
    other rank starts with ``None``. The value is then handed to ``bcast`` with
    root ``0`` and whatever ``bcast`` returns is the result.
    """
    if MPI_rank == 0:
        root_path = tempfile.mkdtemp()
    else:
        root_path = None
    return bcast(root_path, 0)
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_basic(st):
    """Check populations, spike counts, node ids and spike times on every rank.

    Each rank adds five 'V1' spikes for the node whose id equals its rank, plus
    one 'V2' spike for node ``MPI_size`` timed at the rank number. The buffer is
    then queried with the default ``on_rank`` as well as 'all' and 'local'.
    """
    both_pops = {'V1', 'V2'}
    v1_total = MPI_size*5
    rank_ids = np.arange(MPI_size)
    five_zeros = [0.0]*5

    st.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank]*5)
    st.add_spike(population='V2', node_id=MPI_size, timestamp=float(MPI_rank))

    # Population names
    assert set(st.populations) == both_pops
    assert set(st.get_populations(on_rank='all')) == both_pops
    assert set(st.get_populations(on_rank='local')) == both_pops

    # Spike counts: global totals scale with the number of ranks, local ones do not
    assert st.n_spikes('V1') == v1_total
    assert st.n_spikes('V1', on_rank='all') == v1_total
    assert st.n_spikes('V1', on_rank='local') == 5
    assert st.n_spikes('V2') == MPI_size
    assert st.n_spikes('V2', on_rank='all') == MPI_size
    assert st.n_spikes('V2', on_rank='local') == 1

    # Node ids
    v1_ids = st.node_ids('V1')
    assert np.all(np.sort(v1_ids) == rank_ids)
    v1_ids = st.node_ids('V1', on_rank='all')
    assert np.all(np.sort(v1_ids) == rank_ids)
    assert np.all(st.node_ids('V1', on_rank='local') == [MPI_rank])
    assert np.all(st.node_ids('V2') == [MPI_size])
    assert np.all(st.node_ids('V2', on_rank='all') == [MPI_size])
    assert np.all(st.node_ids('V2', on_rank='local') == [MPI_size])

    # Spike times of V1 node 0, which only rank 0 added
    assert np.allclose(st.get_times(population='V1', node_id=0), five_zeros)
    assert np.allclose(st.get_times(population='V1', node_id=0, on_rank='all'), five_zeros)
    times = st.get_times(population='V1', node_id=0, on_rank='local')
    if MPI_rank != 0:
        assert len(times) == 0
    else:
        assert np.allclose(times, five_zeros)

    # Spike times of the V2 node that every rank contributed to
    v2_times = st.get_times(population='V2', node_id=MPI_size, on_rank='all')
    assert np.allclose(np.sort(v2_times), np.arange(MPI_size).astype(float))
    v2_times = st.get_times(population='V2', node_id=MPI_size, on_rank='local')
    assert np.allclose(v2_times, [float(MPI_rank)])
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_basic_root(st):
    """Check that ``on_rank='root'`` queries give data on rank 0 and ``None`` elsewhere.

    Every rank adds one 'V1' spike for node 0 (timed at its rank number), one
    'V1' spike for node 1, and three spikes in a rank-specific 'R<rank+1>'
    population.
    """
    st.add_spikes(population='V1', node_ids=0, timestamps=[float(MPI_rank)])
    st.add_spike(population='V1', node_id=1, timestamp=1.0)
    st.add_spikes(
        population='R{}'.format(MPI_rank+1),
        node_ids=MPI_size+1,
        timestamps=[0.1, 0.2, 0.3]
    )

    # All four queries are issued on every rank before any result is inspected.
    root_pops = st.get_populations(on_rank='root')
    root_count = st.n_spikes('V1', on_rank='root')
    root_ids = st.node_ids('V1', on_rank='root')
    root_times = st.get_times(population='V1', node_id=0, on_rank='root')

    if MPI_rank != 0:
        assert root_pops is None
        assert root_count is None
        assert root_ids is None
        assert root_times is None
        return

    expected_pops = {'V1'}.union('R{}'.format(r+1) for r in range(MPI_size))
    assert set(root_pops) == expected_pops
    assert root_count == MPI_size*2
    assert set(root_ids) == {0, 1}
    assert np.allclose(np.sort(root_times), np.arange(MPI_size, dtype=float))
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_split_ids(st):
    """Check counts, node ids and dataframes when each rank holds a different node id.

    Every rank adds five 'V1' spikes for the node whose id equals its rank.
    """
    total_rows = MPI_size*5
    rank_ids = np.arange(MPI_size)

    st.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank/10.0]*5)

    assert set(st.populations) == {'V1'}
    assert st.n_spikes('V1', on_rank='local') == 5
    assert st.n_spikes('V1', on_rank='all') == 5*MPI_size
    gathered_ids = st.node_ids(population='V1', on_rank='all')
    assert np.all(np.sort(gathered_ids) == rank_ids)

    all_df = st.to_dataframe(populations='V1', on_rank='all')
    assert len(all_df) == total_rows
    assert np.all(np.sort(all_df['node_ids'].unique()) == rank_ids)

    # With on_rank='root' only rank 0 receives a dataframe.
    root_df = st.to_dataframe(populations='V1', on_rank='root')
    if MPI_rank != 0:
        assert root_df is None
    else:
        assert len(root_df) == total_rows
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_to_dataframe(st):
    """Check ``to_dataframe`` with ``on_rank`` set to 'all', 'local' and 'root'.

    Every rank adds 5 + 1 spikes to 'V1' and 5 spikes to its own 'R<rank>'
    population, i.e. 11 rows per rank.
    """
    st.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank / 10.0]*5)
    st.add_spike(population='V1', node_id=MPI_size, timestamp=float(MPI_rank))
    st.add_spikes(
        population='R{}'.format(MPI_rank),
        node_ids=MPI_size+1,
        timestamps=np.linspace(0.0, 1.0, 5)
    )

    global_rows = 5*MPI_size + MPI_size + 5*MPI_size
    global_pops = {'V1'}.union('R{}'.format(r) for r in range(MPI_size))
    v1_node_ids = set(range(MPI_size+1))

    all_df = st.to_dataframe(on_rank='all')
    assert len(all_df) == global_rows
    assert set(all_df['population'].unique()) == global_pops
    v1_rows = all_df[all_df['population'] == 'V1']
    assert set(v1_rows['node_ids'].unique()) == v1_node_ids

    local_df = st.to_dataframe(on_rank='local')
    assert len(local_df) == 5 + 1 + 5
    assert set(local_df['population'].unique()) == {'R{}'.format(MPI_rank), 'V1'}

    # With on_rank='root' only rank 0 receives a dataframe.
    root_df = st.to_dataframe(on_rank='root')
    if MPI_rank != 0:
        assert root_df is None
    else:
        assert len(root_df) == global_rows
        assert set(root_df['population'].unique()) == global_pops
        v1_rows = root_df[root_df['population'] == 'V1']
        assert set(v1_rows['node_ids'].unique()) == v1_node_ids
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_iterator(st):
    """Check the ``spikes()`` iterator with ``on_rank`` set to 'all', 'local' and 'root'.

    Every rank adds 5 + 1 spikes to 'V1' and 5 spikes to its own 'R<rank>'
    population, i.e. 11 spikes per rank. Each yielded item is indexed as
    (timestamp, population, node id).
    """
    st.add_spikes(population='V1', node_ids=MPI_rank, timestamps=[MPI_rank / 10.0]*5)
    st.add_spike(population='V1', node_id=MPI_size, timestamp=float(MPI_rank))
    st.add_spikes(population='R{}'.format(MPI_rank), node_ids=MPI_size+1, timestamps=np.linspace(0.0, 1.0, 5))

    all_spikes = list(st.spikes(on_rank='all'))
    assert len(all_spikes) == 5*MPI_size + MPI_size + 5*MPI_size
    # The original tuple listed ``float`` twice. np.floating additionally accepts
    # NumPy float scalars that do not subclass the builtin float (e.g. float32).
    assert isinstance(all_spikes[0][0], (float, np.floating))
    assert isinstance(all_spikes[0][1], string_types)
    # np.integer covers every NumPy integer scalar rather than only np.uint,
    # so the check does not depend on which integer dtype stores the ids.
    assert isinstance(all_spikes[0][2], (int, np.integer))
    assert {s[1] for s in all_spikes} == {'R{}'.format(r) for r in range(MPI_size)} | {'V1'}
    assert {s[2] for s in all_spikes} == set(range(MPI_size+2))

    local_spikes = list(st.spikes(on_rank='local'))
    assert len(local_spikes) == 11
    assert {s[1] for s in local_spikes} == {'R{}'.format(MPI_rank), 'V1'}

    # With on_rank='root' the iterator is empty on every rank except rank 0.
    root_spikes = list(st.spikes(on_rank='root'))
    if MPI_rank == 0:
        assert len(root_spikes) == 5*MPI_size + MPI_size + 5*MPI_size
        assert {s[1] for s in root_spikes} == {'R{}'.format(r) for r in range(MPI_size)} | {'V1'}
    else:
        assert len(root_spikes) == 0
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_no_root_spikes(st):
    """Check the buffer when rank 0 contributes no spikes.

    Regression test for a past issue where one of the ranks did not have any
    spiking neurons. Every rank except 0 adds five spikes to its own
    'R<rank>' population.
    """
    if MPI_rank > 0:
        st.add_spikes(
            population='R{}'.format(MPI_rank),
            node_ids=MPI_rank,
            timestamps=np.linspace(0.0, 1.0, 5)
        )

    expected_pops = {'R{}'.format(r) for r in range(1, MPI_size)}
    assert set(st.populations) == expected_pops

    all_df = st.to_dataframe(on_rank='all')
    assert all_df.shape == ((MPI_size-1)*5, 3)

    # Rank 0 added nothing, so its local frame has zero rows.
    local_rows = 0 if MPI_rank == 0 else 5
    local_df = st.to_dataframe(on_rank='local')
    assert local_df.shape == (local_rows, 3)
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_root_spikesonly(st):
    """Check the buffer when only rank 0 contributes spikes.

    Rank 0 adds five spikes to its 'R<rank>' population; no other rank adds
    anything.
    """
    if MPI_rank == 0:
        st.add_spikes(
            population='R{}'.format(MPI_rank),
            node_ids=MPI_rank,
            timestamps=np.linspace(0.0, 1.0, 5)
        )

    assert st.populations == ['R0']

    all_df = st.to_dataframe(on_rank='all')
    assert all_df.shape == (5, 3)

    # Only rank 0 holds rows locally.
    local_rows = 5 if MPI_rank == 0 else 0
    local_df = st.to_dataframe(on_rank='local')
    assert local_df.shape == (local_rows, 3)
@pytest.mark.skipif(not has_mpi, reason='Can only run test using mpi')
@pytest.mark.parametrize('st', [
    STMPIBuffer(default_population='V1'),
    STCSVMPIBufferV2(cache_dir=tmpdir()),
])
def test_no_spikes(st):
    """Make sure the buffer still works when no rank has added any spikes.

    Covers the population list, ``to_dataframe`` and the ``spikes()`` iterator
    for each of the 'all', 'local' and 'root' modes.
    """
    assert len(st.populations) == 0
    assert st.to_dataframe(on_rank='all').shape == (0, 3)
    assert st.to_dataframe(on_rank='local').shape == (0, 3)

    # With on_rank='root' only rank 0 receives a dataframe.
    df = st.to_dataframe(on_rank='root')
    if MPI_rank != 0:
        assert df is None
    else:
        assert df.shape == (0, 3)

    # The original repeated the on_rank='all' check verbatim and never exercised
    # the 'local' or 'root' iterators on an empty buffer; each mode is now checked.
    assert list(st.spikes(on_rank='all')) == []
    assert list(st.spikes(on_rank='local')) == []
    assert list(st.spikes(on_rank='root')) == []

    # Positional form kept from the original.
    # NOTE(review): confirm which parameter of spikes() the first positional argument binds to.
    assert list(st.spikes('all')) == []
if __name__ == '__main__':
    # Manual entry point: calls one test function directly instead of going
    # through pytest. Uncomment a different line to run another case.
    # test_basic(STMPIBuffer(default_population='V1'))
    # test_basic(STCSVMPIBufferV2(cache_dir=tmpdir()))
    # test_basic_root(STMPIBuffer(default_population='V1'))
    # test_basic_root(STCSVMPIBufferV2(cache_dir=tmpdir()))
    # test_split_ids(STMPIBuffer(default_population='V1'))
    # test_split_ids(STCSVMPIBufferV2(cache_dir=tmpdir()))
    # test_to_dataframe(STMPIBuffer(default_population='V1'))
    # test_to_dataframe(STCSVMPIBufferV2(cache_dir=tmpdir()))
    # test_iterator(STMPIBuffer(default_population='V1'))
    # test_iterator(STCSVMPIBufferV2(cache_dir=tmpdir()))
    # test_no_root_spikes(STMPIBuffer(default_population='V1'))
    # test_root_spikesonly(STCSVMPIBufferV2(cache_dir=tmpdir()))
    # test_no_spikes(STCSVMPIBufferV2(cache_dir=tmpdir()))
    in_memory_buffer = STMPIBuffer(default_population='V1')
    test_no_spikes(in_memory_buffer)