# Source code for bmtk.tests.builder.test_index_builders

import pytest
import tempfile
import h5py
import numpy as np

from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version, check_magic, get_version
from bmtk.builder.index_builders import create_index_on_disk, create_index_in_memory


def _check_index(h5, index_col, id_to_range_col, range_to_edge_col, n_edges):
    """Assert that a two-level edge index stored in ``h5`` agrees with the column it indexes.

    Row ``i`` of ``id_to_range_col`` is read as a half-open ``[start, stop)`` span of rows in
    ``range_to_edge_col``; each of those rows is in turn read as a half-open ``[start, stop)``
    span of positions in ``index_col``. Every position reached this way must hold the value
    ``i``, and the total number of positions reached must equal ``n_edges``.

    :param h5: opened file-like container supporting ``key in h5`` and ``h5[key][()]``.
    :param index_col: key of the column that was indexed.
    :param id_to_range_col: key of the id -> range lookup table.
    :param range_to_edge_col: key of the range -> edge-position lookup table.
    :param n_edges: number of edges the index is expected to cover.
    :raises AssertionError: if a key is missing, an indexed position holds a different id than
        the row that points at it, or the number of positions covered differs from ``n_edges``.
    """
    for required_col in (index_col, id_to_range_col, range_to_edge_col):
        assert required_col in h5, '{} is missing'.format(required_col)

    id_lu = h5[id_to_range_col][()]
    edge_idx_lu = h5[range_to_edge_col][()]
    # Read the indexed column once up front instead of looking the dataset up (and reading a
    # single element from it) for every edge inside the innermost loop.
    indexed_vals = h5[index_col][()]

    edge_count = 0
    for i, id_range in enumerate(id_lu):
        # An empty (or inverted) span produces an empty range(), so ids without edges are skipped.
        for range_idx in range(id_range[0], id_range[1]):
            edge_range = edge_idx_lu[range_idx]
            for edge_idx in range(edge_range[0], edge_range[1]):
                assert indexed_vals[edge_idx] == i, (
                    '{}[{}] is {}, but the index lists it under id {}'.format(
                        index_col, edge_idx, indexed_vals[edge_idx], i
                    )
                )
                edge_count += 1

    assert edge_count == n_edges, (
        'index covers {} edges, expected {}'.format(edge_count, n_edges)
    )


@pytest.mark.parametrize('indexer_func,indexer_args', [
    (create_index_in_memory, {}),
    (create_index_on_disk, {'max_edge_reads': 10})
])
def test_create_index(indexer_func, indexer_args):
    """Build every edge index on a small edges file and validate each one.

    Writes a 20-edge ``/edges/a_to_b`` population to a temporary HDF5 file, then, for
    ``target_node_id``, ``source_node_id`` and ``edge_type_id`` in turn, runs ``indexer_func``
    on the file and checks the resulting lookup tables with ``_check_index``.

    :param indexer_func: index builder under test; called with ``edges_file``,
        ``edges_population`` and ``index_type`` keyword arguments.
    :param indexer_args: extra keyword arguments forwarded to ``indexer_func``.
    """
    population = '/edges/a_to_b'
    n_edges = 20
    source_node_ids = np.tile([0, 1, 3, 4], 5)
    target_node_ids = np.repeat([73, 72, 52, 4], 5)
    edge_type_ids = np.random.choice([100, 102, 103], size=n_edges, replace=True)
    edge_group_ids = np.full(n_edges, fill_value=0)
    edge_group_indices = np.arange(0, n_edges, dtype=int)

    # (index_type handed to the indexer, group under <population>/indices expected to hold it)
    index_specs = [
        ('target_node_id', 'target_to_source'),
        ('source_node_id', 'source_to_target'),
        ('edge_type_id', 'edge_type_to_index'),
    ]

    # The context manager closes (and thereby removes) the temporary file even when an
    # assertion fails, instead of leaving cleanup to garbage collection.
    with tempfile.NamedTemporaryFile(suffix='.h5') as tmp_edges_h5:
        with h5py.File(tmp_edges_h5.name, 'w') as h5:
            add_hdf5_magic(h5)
            add_hdf5_version(h5)
            h5.create_dataset(population + '/source_node_id', data=source_node_ids)
            h5[population + '/source_node_id'].attrs['node_population'] = 'a'
            h5.create_dataset(population + '/target_node_id', data=target_node_ids)
            h5[population + '/target_node_id'].attrs['node_population'] = 'b'
            h5.create_dataset(population + '/edge_group_id', data=edge_group_ids)
            h5.create_dataset(population + '/edge_group_index', data=edge_group_indices)
            h5.create_dataset(population + '/edge_type_id', data=edge_type_ids)
            h5.create_group(population + '/0')

        # Indexes are added one after another to the same file, in the same order as before.
        for index_type, index_group in index_specs:
            indexer_func(
                edges_file=tmp_edges_h5.name,
                edges_population=population,
                index_type=index_type,
                **indexer_args
            )

            index_root = '{}/indices/{}'.format(population, index_group)
            with h5py.File(tmp_edges_h5.name, 'r') as h5:
                _check_index(
                    h5=h5,
                    index_col='{}/{}'.format(population, index_type),
                    id_to_range_col=index_root + '/node_id_to_range',
                    range_to_edge_col=index_root + '/range_to_edge_id',
                    n_edges=n_edges
                )
if __name__ == '__main__':
    # Allows running the test cases directly (without pytest) with debug logging enabled.
    import logging

    logging.basicConfig(level=logging.DEBUG)

    test_create_index(create_index_in_memory, {})
    test_create_index(create_index_on_disk, {'max_edge_reads': 10})