# Source code for bmtk.tests.builder.test_edges_sorter

import pytest
import tempfile
import h5py
import numpy as np
import logging

from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version, check_magic, get_version
from bmtk.builder.edges_sorter import quicksort_edges, external_merge_sort


def _check_edges(h5, n_edges):
    """Assert that the '/edges/a_to_b' population in ``h5`` is self-consistent.

    Verifies the file's magic/version markers, that every index column holds
    ``n_edges`` entries, that the source/target columns carry the expected
    ``node_population`` attribute ('a' and 'b'), and that each edge's
    (edge_group_id, edge_group_index) pair points at a model-group row whose
    ``src_id``/``trg_id``/``et_id`` equal the edge's own index-column values.

    :param h5: open h5py file (or group-like object) containing the edges.
    :param n_edges: number of edges each index column is expected to hold.
    """
    assert check_magic(h5)
    assert get_version(h5)

    population = '/edges/a_to_b'

    # Node-id columns: check length first, then the population attribute.
    for column, node_population in (('source_node_id', 'a'), ('target_node_id', 'b')):
        dataset = h5[population + '/' + column]
        assert len(dataset) == n_edges
        assert dataset.attrs['node_population'] == node_population

    # Remaining index columns only need the right length.
    for column in ('edge_type_id', 'edge_group_id', 'edge_group_index'):
        assert len(h5[population + '/' + column]) == n_edges

    # Pairs of (index column, matching column inside each model group).
    mirrored_columns = (
        ('source_node_id', 'src_id'),
        ('target_node_id', 'trg_id'),
        ('edge_type_id', 'et_id'),
    )
    for row in range(n_edges):
        group_id = h5[population + '/edge_group_id'][row]
        group_row = h5[population + '/edge_group_index'][row]
        model_group = h5[population][str(group_id)]
        for index_column, model_column in mirrored_columns:
            assert h5[population + '/' + index_column][row] == model_group[model_column][group_row]


@pytest.mark.parametrize('sort_func,sort_params', [
    (quicksort_edges, {}),
    (external_merge_sort, {'sort_model_properties': False, 'n_chunks': 5}),
    (external_merge_sort, {'sort_model_properties': True, 'n_chunks': 5}),
])
def test_sort(sort_func, sort_params):
    """Sort a small edges file by each index column and validate the result.

    Builds a 10-edge '/edges/a_to_b' population split across two model groups
    ('0' and '1'), writes it to a temporary HDF5 file, then for every sortable
    column calls ``sort_func`` and checks that the output column is
    non-decreasing and that the file still passes ``_check_edges``.

    :param sort_func: sorting function under test; called with
        ``input_edges_path``, ``output_edges_path``, ``edges_population`` and
        ``sort_by`` keyword arguments plus ``sort_params``.
    :param sort_params: extra keyword arguments forwarded to ``sort_func``.
    """
    source_node_ids = np.tile([0, 1], 5)
    target_node_ids = np.arange(20, 0, -2, dtype=int)
    edge_type_ids = np.repeat([103, 100, 104, 101, 102], 2)
    edge_group_ids = np.repeat([1, 0], 5)
    edge_group_indices = np.tile(range(5), 2)
    n_edges = 10

    # Context managers make sure every temporary file is closed and removed,
    # even when an assertion fails part-way through.
    with tempfile.NamedTemporaryFile(suffix='.h5') as tmp_edges_h5:
        with h5py.File(tmp_edges_h5.name, 'w') as h5:
            add_hdf5_magic(h5)
            add_hdf5_version(h5)
            h5.create_dataset('/edges/a_to_b/source_node_id', data=source_node_ids)
            h5['/edges/a_to_b/source_node_id'].attrs['node_population'] = 'a'
            h5.create_dataset('/edges/a_to_b/target_node_id', data=target_node_ids)
            h5['/edges/a_to_b/target_node_id'].attrs['node_population'] = 'b'
            h5.create_dataset('/edges/a_to_b/edge_group_id', data=edge_group_ids)
            h5.create_dataset('/edges/a_to_b/edge_group_index', data=edge_group_indices)
            h5.create_dataset('/edges/a_to_b/edge_type_id', data=edge_type_ids)

            # One model group per distinct edge_group_id. (Previously this
            # iterated the unique source_node_ids, which only worked because
            # both columns happen to contain exactly the values {0, 1}.)
            for grp_id in np.unique(edge_group_ids):
                model_grp = h5.create_group('/edges/a_to_b/{}'.format(grp_id))
                grp_mask = edge_group_ids == grp_id
                model_grp.create_dataset('src_id', data=source_node_ids[grp_mask])
                model_grp.create_dataset('trg_id', data=target_node_ids[grp_mask])
                model_grp.create_dataset('et_id', data=edge_type_ids[grp_mask])

        # Sort the same input by each index column in turn, each time into a
        # fresh output file.
        for sort_key in ('source_node_id', 'target_node_id', 'edge_type_id', 'edge_group_id'):
            with tempfile.NamedTemporaryFile(suffix='.h5') as sorted_tmp_edges_h5:
                sort_func(
                    input_edges_path=tmp_edges_h5.name,
                    output_edges_path=sorted_tmp_edges_h5.name,
                    edges_population='/edges/a_to_b',
                    sort_by=sort_key,
                    **sort_params
                )

                with h5py.File(sorted_tmp_edges_h5.name, 'r') as h5:
                    # The sorted column must be non-decreasing.
                    assert np.all(np.diff(h5['/edges/a_to_b/' + sort_key][()]) >= 0)
                    _check_edges(h5, n_edges=n_edges)
if __name__ == '__main__':
    # Script entry point: run a single configuration with debug logging on.
    logging.basicConfig(level=logging.DEBUG)

    # The other configurations are left here for manual toggling.
    # test_sort(quicksort_edges, {})
    test_sort(external_merge_sort, {'sort_model_properties': False, 'n_chunks': 5})
    # test_sort(external_merge_sort, {'sort_model_properties': True, 'n_chunks': 5})