# Source code for FIAT.restricted

# Copyright (C) 2015-2016 Jan Blechta, Andrew T T McRae, and others
#
# This file is part of FIAT (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

from FIAT.dual_set import DualSet
from FIAT.finite_element import CiarletElement


class RestrictedElement(CiarletElement):
    """Restrict given element to specified list of dofs."""

    def __init__(self, element, indices=None, restriction_domain=None):
        """Build an element from a subset of the dofs of ``element``.

        For sake of argument, ``indices`` overrides ``restriction_domain``.

        :arg element: the element to restrict.
        :arg indices: dof numbers of ``element`` to keep.  When given and
            non-empty, ``restriction_domain`` is ignored.
        :arg restriction_domain: a domain name understood by
            ``_get_indices`` ('interior', 'vertex', 'edge', 'face' or
            'facet'); only used when ``indices`` is not supplied.
        :raises RuntimeError: if neither argument is supplied, or if
            ``indices`` is a string (usually a positional-argument slip).
        :raises ValueError: if the restriction selects no dofs.
        """
        # Test for "indices supplied" with None/len rather than truthiness.
        # This gives the same result for lists, tuples and strings, and it also
        # accepts array-like indices, whose truth value may be ambiguous.
        has_indices = indices is not None and len(indices) > 0

        if not (has_indices or restriction_domain):
            raise RuntimeError("Either indices or restriction_domain must be passed in")

        if not has_indices:
            indices = _get_indices(element, restriction_domain)

        if isinstance(indices, str):
            raise RuntimeError("variable 'indices' was a string; did you forget to use a keyword?")

        if len(indices) == 0:
            raise ValueError("No point in creating empty RestrictedElement.")

        self._element = element
        self._indices = indices

        ref_el = element.get_reference_element()

        # Restrict primal set.
        poly_set = element.get_nodal_basis().take(indices)

        # Restrict dual set.  Walk the entity-dof map, keep only the selected
        # dofs, and renumber them consecutively in traversal order.
        keep = set(indices)  # built once: O(1) membership tests in the loop
        old_nodes = element.dual_basis()
        nodes = []
        entity_ids = {}
        for dim, entities in element.entity_dofs().items():
            entity_ids[dim] = {}
            for entity, dofs in entities.items():
                new_ids = []
                for dof in dofs:
                    if dof in keep:
                        new_ids.append(len(nodes))  # next consecutive number
                        nodes.append(old_nodes[dof])
                entity_ids[dim][entity] = new_ids
        # Out-of-range or repeated entries in ``indices`` make this count
        # differ from len(indices).
        assert len(nodes) == len(indices)
        dual = DualSet(nodes, ref_el, entity_ids)

        # Restrict mapping.  Only one mapping is forwarded to the
        # CiarletElement constructor, so every retained dof must share it.
        old_mapping = element.mapping()
        new_mapping = [old_mapping[dof] for dof in indices]
        assert all(m == new_mapping[0] for m in new_mapping)

        # NOTE(review): the literal 0 is forwarded unchanged from the
        # original code; presumably CiarletElement's order argument -- confirm.
        super().__init__(poly_set, dual, 0, element.get_formdegree(), new_mapping[0])
[docs] def sorted_by_key(mapping): "Sort dict items by key, allowing different key types." # Python3 doesn't allow comparing builtins of different type, therefore the typename trick here def _key(x): return (type(x[0]).__name__, x[0]) return sorted(mapping.items(), key=_key)
def _get_indices(element, restriction_domain): "Restriction domain can be 'interior', 'vertex', 'edge', 'face' or 'facet'" if restriction_domain == "interior": # Return dofs from interior return element.entity_dofs()[max(element.entity_dofs().keys())][0] # otherwise return dofs with d <= dim if restriction_domain == "vertex": dim = 0 elif restriction_domain == "edge": dim = 1 elif restriction_domain == "face": dim = 2 elif restriction_domain == "facet": dim = element.get_reference_element().get_spatial_dimension() - 1 else: raise RuntimeError("Invalid restriction domain") is_prodcell = isinstance(max(element.entity_dofs().keys()), tuple) entity_dofs = element.entity_dofs() indices = [] for d in range(dim + 1): if is_prodcell: for a in range(d + 1): b = d - a try: entities = entity_dofs[(a, b)] for (entity, index) in sorted_by_key(entities): indices += index except KeyError: pass else: entities = entity_dofs[d] for (entity, index) in sorted_by_key(entities): indices += index return indices