import os
import shutil
import pytest
import numpy as np
import pandas as pd
import h5py
import tempfile
from bmtk.builder.network_adaptors.dm_network import DenseNetwork
# from bmtk.builder.network_adaptors.dm_network_orig import DenseNetworkOrig
# [docs] -- leftover Sphinx "view source" link marker, not part of the code
@pytest.mark.parametrize('network_cls', [
    (DenseNetwork),
    # (DenseNetworkOrig)  # dropped support for original dense network
])
def test_save_nsyn_table(network_cls):
    """Build a 30-node network, save it, and validate the node and edge files.

    The network has three node types of 10 nodes each ('Scnna1'/e, 'PV1'/i,
    'PV2'/i; only 'PV2' carries a ``tags`` property) and two edge types:
    every 'i' node onto every 'e' node with a connection rule of 1, and every
    'Scnna1' node onto every 'PV1' node with a connection rule of 2.

    Files are written using absolute paths inside a scratch directory which is
    removed when the test finishes, and every HDF5 handle is closed before
    cleanup.

    :param network_cls: network builder class under test; it is instantiated
        as ``network_cls('NET1')``.
    """
    np.random.seed(100)

    net = network_cls('NET1')
    net.add_nodes(N=10, position=[(0.0, 1.0, -1.0)]*10, cell_type='Scnna1', ei='e')
    net.add_nodes(N=10, position=[(0.0, 1.0, -1.0)]*10, cell_type='PV1', ei='i')
    net.add_nodes(N=10, position=[(0.0, 1.0, -1.0)]*10, tags=np.linspace(0, 100, 10), cell_type='PV2', ei='i')
    net.add_edges(source={'ei': 'i'}, target={'ei': 'e'}, connection_rule=lambda s, t: 1,
                  p1='e2i', p2='e2i')
    net.add_edges(source=net.nodes(cell_type='Scnna1'), target=net.nodes(cell_type='PV1'),
                  connection_rule=lambda s, t: 2, p1='s2p')
    net.build()

    # mkdtemp() returns an absolute path, so the save calls below still receive
    # absolute file names, but nothing is held open while the builder writes.
    tmp_dir = tempfile.mkdtemp()
    try:
        nodes_h5_path = os.path.join(tmp_dir, 'nodes.h5')
        nodes_csv_path = os.path.join(tmp_dir, 'node_types.csv')
        edges_h5_path = os.path.join(tmp_dir, 'edges.h5')
        edges_csv_path = os.path.join(tmp_dir, 'edge_types.csv')

        net.save_nodes(nodes_h5_path, nodes_csv_path)
        net.save_edges(edges_h5_path, edges_csv_path)

        # --- node types csv ---
        assert os.path.exists(nodes_h5_path) and os.path.exists(nodes_csv_path)
        node_types_df = pd.read_csv(nodes_csv_path, sep=' ')
        assert len(node_types_df) == 3
        assert 'cell_type' in node_types_df.columns
        assert 'ei' in node_types_df.columns
        assert 'positions' not in node_types_df.columns

        # --- nodes h5 ---
        with h5py.File(nodes_h5_path, 'r') as nodes_h5:
            nodes_grp = nodes_h5['/nodes/NET1']
            assert 'node_id' in nodes_grp
            for dset_name in ('node_id', 'node_type_id', 'node_group_id', 'node_group_index'):
                assert len(nodes_grp[dset_name]) == 30

            # Sub-groups hold the per-node properties. Expect one group with only
            # 'position' (20 nodes) and one with 'position' + 'tags' (10 nodes).
            prop_groups = [grp for grp in nodes_grp.values() if isinstance(grp, h5py.Group)]
            for grp in prop_groups:
                if len(grp) == 1:
                    assert 'position' in grp and len(grp['position']) == 20
                elif len(grp) == 2:
                    assert 'position' in grp and len(grp['position']) == 10
                    assert 'tags' in grp and len(grp['tags']) == 10
                else:
                    assert False, 'unexpected number of properties in node group: {}'.format(len(grp))

        # --- edge types csv ---
        assert os.path.exists(edges_h5_path) and os.path.exists(edges_csv_path)
        edge_types_df = pd.read_csv(edges_csv_path, sep=' ')
        assert len(edge_types_df) == 2
        assert 'p1' in edge_types_df.columns
        assert 'p2' in edge_types_df.columns

        # --- edges h5 ---
        with h5py.File(edges_h5_path, 'r') as edges_h5:
            edges_grp = edges_h5['/edges/NET1_to_NET1']
            assert 'source_to_target' in edges_grp['indices']
            assert 'target_to_source' in edges_grp['indices']
            assert len(edges_grp['target_node_id']) == 300
            assert len(edges_grp['source_node_id']) == 300

            target_ids = edges_grp['target_node_id'][()]
            source_ids = edges_grp['source_node_id'][()]

            # Check edges and node ids match up.
            # Warning: the builder may not build edges in sequential order, so the
            # rows for a given target are looked up rather than assumed contiguous.
            expected = [
                # (target node, lowest source id, highest source id + 1, edge_type_id, nsyns)
                (0, 10, 30, 100, 1),   # an 'e' node: sources are the 'i' nodes
                (19, 0, 10, 101, 2),   # a 'PV1' node: sources are the 'Scnna1' nodes
            ]
            for target_id, src_lo, src_hi, edge_type_id, nsyns in expected:
                row_idxs = np.sort(np.argwhere(target_ids == target_id).flatten())
                assert len(row_idxs) > 0, 'no edges found targeting node {}'.format(target_id)

                src_ids = source_ids[row_idxs]
                assert np.all(src_ids >= src_lo)
                assert np.all(src_hi > src_ids)

                first_row = row_idxs[0]
                assert edges_grp['edge_type_id'][first_row] == edge_type_id

                # The (group id, group index) pair locates the edge's properties.
                grp_id = edges_grp['edge_group_id'][first_row]
                grp_idx = edges_grp['edge_group_index'][first_row]
                assert edges_grp[str(grp_id)]['nsyns'][grp_idx] == nsyns
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
# [docs] -- leftover Sphinx "view source" link marker, not part of the code
@pytest.mark.parametrize('network_cls', [
    (DenseNetwork),
    # (DenseNetworkOrig)
])
def test_save_weights(network_cls):
    """Save a 300-node network whose edges carry extra properties and check the edge file.

    One edge type ('i' -> 'e', connection rule 3) has 'segment' and 'distance'
    properties attached through ``add_properties``; the other ('Scnna1' ->
    'PV1', connection rule 2) has none. The test checks the total edge count
    reported by the network and the sizes/values of the saved property tables.

    Output goes to a scratch directory that is removed afterwards, and the
    HDF5 file is closed before cleanup.

    :param network_cls: network builder class under test; it is instantiated
        as ``network_cls('NET1')``.
    """
    net = network_cls('NET1')
    net.add_nodes(N=100, position=[(0.0, 1.0, -1.0)]*100, cell_type='Scnna1', ei='e')
    net.add_nodes(N=100, position=[(0.0, 1.0, -1.0)]*100, cell_type='PV1', ei='i')
    net.add_nodes(N=100, position=[(0.0, 1.0, -1.0)]*100, tags=np.linspace(0, 100, 100), cell_type='PV2', ei='i')

    # 200 'i' sources x 100 'e' targets x 3 per pair = 60000 edges
    cm = net.add_edges(source={'ei': 'i'}, target={'ei': 'e'}, connection_rule=lambda s, t: 3,
                       p1='e2i', p2='e2i')
    cm.add_properties(names=['segment', 'distance'], rule=lambda s, t: [1, 0.5], dtypes=[int, float])

    # 100 sources x 100 targets x 2 per pair = 20000 edges (stored as 10000 rows with nsyns == 2)
    net.add_edges(source=net.nodes(cell_type='Scnna1'), target=net.nodes(cell_type='PV1'),
                  connection_rule=lambda s, t: 2, p1='s2p')
    net.build()

    net_dir = tempfile.mkdtemp()
    try:
        net.save_nodes('tmp_nodes.h5', 'tmp_node_types.csv', output_dir=net_dir)
        net.save_edges('tmp_edges.h5', 'tmp_edge_types.csv', output_dir=net_dir)

        assert net.nedges == 80000

        with h5py.File(os.path.join(net_dir, 'tmp_edges.h5'), 'r') as edges_h5:
            edges_grp = edges_h5['/edges/NET1_to_NET1']

            # Group 0: the edges with explicit 'segment'/'distance' properties.
            assert len(edges_grp['0/distance']) == 60000
            assert len(edges_grp['0/segment']) == 60000
            assert edges_grp['0/distance'][0] == 0.5
            assert edges_grp['0/segment'][0] == 1

            # Group 1: the property-less edges, one row per pair with an nsyns count.
            assert len(edges_grp['1/nsyns']) == 10000
            assert edges_grp['1/nsyns'][0] == 2
    finally:
        shutil.rmtree(net_dir, ignore_errors=True)
# [docs] -- leftover Sphinx "view source" link marker, not part of the code
@pytest.mark.parametrize('network_cls', [
    (DenseNetwork),
    # (DenseNetworkOrig)
])
def test_save_multinetwork(network_cls):
    """Save two inter-connected networks and check each of the four edge files.

    NET1 (100 nodes) is recurrently connected; NET2 (10 nodes) is recurrently
    connected and also has edges from NET1 and edges onto NET1. After both
    networks call ``save_edges`` into the same directory, the test expects one
    ``<src>_<trg>_edges.h5`` / ``<src>_<trg>_edge_types.csv`` pair per
    source/target combination, each holding only its own edge type.

    Output goes to a scratch directory that is removed afterwards, and every
    HDF5 file is closed before cleanup.

    :param network_cls: network builder class under test; both networks are
        created from it.
    """
    # Use the parametrized class for both networks so the parametrization takes effect.
    net1 = network_cls('NET1')
    net1.add_nodes(N=100, position=[(0.0, 1.0, -1.0)] * 100, cell_type='Scnna1', ei='e')
    net1.add_edges(source={'ei': 'e'}, target={'ei': 'e'}, connection_rule=5, ctype_1='n1_rec')
    net1.build()

    net2 = network_cls('NET2')
    net2.add_nodes(N=10, position=[(0.0, 1.0, -1.0)] * 10, cell_type='PV1', ei='i')
    net2.add_edges(connection_rule=10, ctype_1='n2_rec')
    net2.add_edges(source=net1.nodes(), target={'ei': 'i'}, connection_rule=1, ctype_2='n1_n2')
    net2.add_edges(target=net1.nodes(cell_type='Scnna1'), source={'cell_type': 'PV1'}, connection_rule=2,
                   ctype_2='n2_n1')
    net2.build()

    expected = [
        # (source net, target net, n edge rows, nsyns, column present, its value, column absent)
        ('NET1', 'NET1', 100*100, 5, 'ctype_1', 'n1_rec', 'ctype_2'),
        ('NET1', 'NET2', 100*10, 1, 'ctype_2', 'n1_n2', 'ctype_1'),
        ('NET2', 'NET1', 100*10, 2, 'ctype_2', 'n2_n1', 'ctype_1'),
        ('NET2', 'NET2', 10*10, 10, 'ctype_1', 'n2_rec', 'ctype_2'),
    ]

    net_dir = tempfile.mkdtemp()
    try:
        net1.save_edges(output_dir=net_dir)
        net2.save_edges(output_dir=net_dir)

        for src_net, trg_net, n_rows, nsyns, present_col, present_val, absent_col in expected:
            base_name = '{}/{}_{}'.format(net_dir, src_net, trg_net)

            with h5py.File(base_name + '_edges.h5', 'r') as edges_h5:
                edges_grp = edges_h5['/edges/{}_to_{}'.format(src_net, trg_net)]
                assert len(edges_grp['target_node_id']) == n_rows
                assert len(edges_grp['0/nsyns']) == n_rows
                assert edges_grp['0/nsyns'][0] == nsyns

            edge_types_csv = pd.read_csv(base_name + '_edge_types.csv', sep=' ')
            assert len(edge_types_csv) == 1
            assert absent_col not in edge_types_csv.columns.values
            assert edge_types_csv[present_col].iloc[0] == present_val
    finally:
        shutil.rmtree(net_dir, ignore_errors=True)
if __name__ == '__main__':
    # Manual entry point: only the nsyn-table test is run when this file is
    # executed directly; the other tests are left disabled here.
    test_save_nsyn_table(network_cls=DenseNetwork)
    # test_save_weights(DenseNetwork)
    # test_save_multinetwork(DenseNetwork)