"""Tests for the spike-train plotting helpers (bmtk.tests.utils.reports.spike_trains.test_plotting)."""

import pytest
import numpy as np
import tempfile

from bmtk.utils.reports import SpikeTrains
from bmtk.utils.reports.spike_trains import plotting

# Skip every test in this module when matplotlib is not installed;
# importorskip hands back the imported matplotlib module otherwise.
matplotlib = pytest.importorskip('matplotlib')

# The parametrized tests below open many figures without closing them, so
# silence matplotlib's "too many open figures" warning (max_open_warning = 0).
matplotlib.rcParams['figure.max_open_warning'] = 0


@pytest.fixture
def spike_trains():
    """Build a small in-memory SpikeTrains object shared by the plotting tests.

    The default population is 'V1' with node ids 0-19. Each node gets 10
    sorted spike times drawn uniformly from [0.0, 1500.0). A locally seeded
    generator keeps the data, and therefore any failure, reproducible between
    runs without touching numpy's global random state.
    """
    rng = np.random.default_rng(seed=0)
    st = SpikeTrains(default_population='V1')
    for n in range(0, 20):
        times = np.sort(rng.uniform(0.0, 1500.0, 10))
        st.add_spikes(node_ids=n, timestamps=times)
    return st
def test_load_spikes_api(spike_trains):
    """plot_raster accepts an in-memory SpikeTrains object and yields a Figure."""
    raster_figure = plotting.plot_raster(spike_trains=spike_trains, show=False)
    assert isinstance(raster_figure, matplotlib.figure.Figure)
def test_load_spikes_file(spike_trains):
    """plot_raster accepts a path to a SONATA spikes file and yields a Figure.

    The spikes are written into a temporary directory instead of an open
    NamedTemporaryFile. This avoids leaving a file handle open, and an
    already-open named temp file cannot be reopened by name on Windows. The
    directory and its contents are removed when the ``with`` block exits.
    """
    import os

    with tempfile.TemporaryDirectory() as tmpdir:
        spikes_path = os.path.join(tmpdir, 'spikes.h5')
        spike_trains.to_sonata(spikes_path)
        fig = plotting.plot_raster(spike_trains=spikes_path, show=False)
        assert isinstance(fig, matplotlib.figure.Figure)
@pytest.mark.parametrize('node_groups', [
    None,
    [{'node_ids': np.arange(10, 20), 'label': 'all', 'c': 'k'}],
    [{'node_ids': [0, 1, 2, 3, 4, 5, 7, 8, 9, 10], 'label': 'low'},
     {'node_ids': np.array([11, 12, 13, 14, 15]), 'label': 'mid'},
     {'node_ids': range(16, 22), 'label': 'high'}]
])
@pytest.mark.parametrize('with_histogram', [
    True,
    False
])
def test_plot_raster(spike_trains, node_groups, with_histogram):
    """plot_raster returns a Figure with the expected number of axes.

    Expects two axes when a histogram is requested and a single axis otherwise.
    """
    fig = plotting.plot_raster(spike_trains=spike_trains, node_groups=node_groups,
                               with_histogram=with_histogram, show=False)
    assert isinstance(fig, matplotlib.figure.Figure)
    # The conditional must be evaluated on its own. Written inline without
    # parentheses it parses as ``(len(...) == 2) if with_histogram else 1``,
    # which is always truthy when with_histogram is False.
    expected_n_axes = 2 if with_histogram else 1
    assert len(fig.axes) == expected_n_axes
@pytest.mark.parametrize('node_groups', [
    None,
    [{'node_ids': np.arange(10, 20), 'label': 'all', 'c': 'k'}],
    [
        {'node_ids': [0, 1, 2, 3, 4, 5, 7, 8, 9, 10], 'label': 'low'},
        {'node_ids': np.array([11, 12, 13, 14, 15]), 'label': 'mid'},
        {'node_ids': range(16, 22), 'label': 'high'},
    ],
])
@pytest.mark.parametrize('smoothing', [True, False, None])
def test_plot_rates(spike_trains, node_groups, smoothing):
    """plot_rates yields a Figure for every node-grouping / smoothing combination."""
    plot_kwargs = dict(spike_trains=spike_trains, node_groups=node_groups,
                       smoothing=smoothing, show=False)
    rates_figure = plotting.plot_rates(**plot_kwargs)
    assert isinstance(rates_figure, matplotlib.figure.Figure)
@pytest.mark.parametrize('node_groups', [
    None,
    [{'node_ids': np.arange(10, 20), 'label': 'all'}],
    [{'node_ids': [0, 1, 2, 3, 4, 5, 7, 8, 9, 10], 'label': 'low'},
     {'node_ids': np.array([11, 12, 13, 14, 15]), 'label': 'mid'},
     {'node_ids': range(16, 22), 'label': 'high'}]
])
def test_plot_rates_boxplot(spike_trains, node_groups):
    """plot_rates_boxplot returns a Figure for each node-grouping case."""
    # Exercise plot_rates_boxplot itself; plot_rates is already covered by
    # test_plot_rates.
    # NOTE(review): assumes plot_rates_boxplot accepts ``show=`` and returns the
    # Figure like plot_raster/plot_rates do — confirm against plotting module.
    fig = plotting.plot_rates_boxplot(spike_trains=spike_trains, node_groups=node_groups, show=False)
    assert isinstance(fig, matplotlib.figure.Figure)
def show_plot():
    """Manual visual check: draw a firing-rate boxplot for a synthetic population.

    The name lacks the ``test_`` prefix, so pytest does not collect it by
    default. It builds a 'V1' population of 100 nodes whose spike counts follow
    a half sine wave (between 10 and 160 spikes per node, times drawn uniformly
    from [0.0, 1500.0)). It then calls plot_rates_boxplot and prints the node
    groups used.
    """
    def make_node_groups():
        # Return a fresh list on each call, so the groups printed below are a
        # separate object from the one handed to the plotting function.
        return [{'node_ids': [0, 1, 2, 3, 4, 5, 7, 8, 9, 10], 'label': 'low'},
                {'node_ids': np.array([11, 12, 13, 14, 15]), 'label': 'mid'},
                {'node_ids': range(16, 110), 'label': 'high'}]

    st = SpikeTrains(default_population='V1')
    for n in range(0, 100):
        n_spikes = int(np.sin(n*np.pi/100)*150 + 10)
        times = np.sort(np.random.uniform(0.0, 1500.0, n_spikes))
        st.add_spikes(node_ids=n, timestamps=times)

    node_groups = make_node_groups()
    plotting.plot_rates_boxplot(spike_trains=st, node_groups=make_node_groups())
    print(node_groups)
if __name__ == '__main__':
    # pytest fixtures such as ``spike_trains`` may not be called directly
    # (an error since pytest 4), so let pytest build them and run the raster
    # tests instead of invoking test_plot_raster(spike_trains=spike_trains(), ...)
    # by hand. Propagate pytest's exit status to the shell.
    raise SystemExit(pytest.main([__file__, '-k', 'test_plot_raster']))