# Source code for FIAT.restricted
# Copyright (C) 2015-2016 Jan Blechta, Andrew T T McRae, and others
#
# This file is part of FIAT (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
from FIAT.dual_set import DualSet
from FIAT.finite_element import CiarletElement
class RestrictedElement(CiarletElement):
    """Restrict given element to specified list of dofs."""

    def __init__(self, element, indices=None, restriction_domain=None):
        """Build the restriction of *element* to a subset of its dofs.

        For sake of argument, *indices* overrides *restriction_domain*.

        :arg element: the element to restrict.
        :arg indices: dof numbers of *element* to keep.  ``None`` or an empty
            sequence means "derive them from *restriction_domain*".  Emptiness
            is tested with ``len`` rather than truthiness, so array-likes such
            as numpy arrays are accepted too.
        :arg restriction_domain: ``'interior'``, ``'vertex'``, ``'edge'``,
            ``'face'`` or ``'facet'``; only used when *indices* is missing.
        :raises RuntimeError: if neither argument is given, if *indices* is a
            string, or (from ``_get_indices``) if *restriction_domain* is not
            recognised.
        :raises ValueError: if the selection is empty, contains duplicate dof
            numbers or numbers not listed in ``element.entity_dofs()``, or
            mixes dofs that have different mappings.
        """
        missing = indices is None or len(indices) == 0
        if missing and not restriction_domain:
            raise RuntimeError("Either indices or restriction_domain must be passed in")
        if missing:
            indices = _get_indices(element, restriction_domain)
        if isinstance(indices, str):
            raise RuntimeError("variable 'indices' was a string; did you forget to use a keyword?")
        if len(indices) == 0:
            raise ValueError("No point in creating empty RestrictedElement.")

        self._element = element
        self._indices = indices

        # Fetch reference element
        ref_el = element.get_reference_element()

        # Restrict primal set
        poly_set = element.get_nodal_basis().take(indices)

        # Restrict dual set: keep the selected nodes and renumber them
        # consecutively in the order element.entity_dofs() lists them.
        keep = set(indices)  # O(1) membership instead of a list scan per dof
        nodes_old = element.dual_basis()
        nodes = []
        entity_ids = {}
        for d, entities in element.entity_dofs().items():
            entity_ids[d] = {}
            for entity, dofs in entities.items():
                entity_ids[d][entity] = []
                for dof in dofs:
                    if dof in keep:
                        entity_ids[d][entity].append(len(nodes))
                        nodes.append(nodes_old[dof])
        # Each requested dof must be matched exactly once so the dual set lines
        # up with the members taken from the nodal basis.  This is an explicit
        # check (not an assert) so it also runs under ``python -O``.
        if len(nodes) != len(indices):
            raise ValueError(
                "indices must be distinct dof numbers of the element "
                "({} given, {} matched)".format(len(indices), len(nodes))
            )
        dual = DualSet(nodes, ref_el, entity_ids)

        # Restrict mapping: the CiarletElement constructor below receives a
        # single mapping, so every kept dof must agree on it.
        mapping_old = element.mapping()
        mapping_new = [mapping_old[dof] for dof in indices]
        if any(mapping != mapping_new[0] for mapping in mapping_new):
            raise ValueError("All dofs of a RestrictedElement must share the same mapping")

        # Call constructor of CiarletElement
        super().__init__(poly_set, dual, 0, element.get_formdegree(), mapping_new[0])
def sorted_by_key(mapping):
    """Return the ``(key, value)`` pairs of *mapping* as a list sorted by key.

    Python 3 refuses to order values of unrelated builtin types, so each pair
    is ranked first by the name of its key's type and only then by the key
    itself.
    """
    def _type_then_key(pair):
        key = pair[0]
        return type(key).__name__, key

    pairs = list(mapping.items())
    pairs.sort(key=_type_then_key)
    return pairs
def _entity_dimension(key):
    """Return the topological dimension encoded by an ``entity_dofs`` key.

    Plain integer keys are returned unchanged.  Tuple keys (as used by
    tensor-product cells) hold one dimension per factor, so the entity
    dimension is their sum; nested tuples are summed recursively.
    """
    if isinstance(key, tuple):
        return sum(_entity_dimension(part) for part in key)
    return key


def _get_indices(element, restriction_domain):
    """Return the dofs of *element* selected by *restriction_domain*.

    Restriction domain can be 'interior', 'vertex', 'edge', 'face' or 'facet'.

    * ``'interior'`` returns (a copy of) the dofs of entity 0 under the
      largest key of ``element.entity_dofs()``.
    * The other domains pick a maximum entity dimension ``dim`` (0, 1, 2, or
      the reference element's spatial dimension minus one for ``'facet'``)
      and return every dof on an entity of dimension ``<= dim``, ordered by
      entity dimension, then entity key, then entity number.  Dimensions the
      element has no entities for are skipped rather than raising.

    :raises RuntimeError: if *restriction_domain* is not one of the above.
    """
    entity_dofs = element.entity_dofs()

    if restriction_domain == "interior":
        # Return dofs from interior.  Copy so callers never alias the
        # element's own dof list.
        return list(entity_dofs[max(entity_dofs.keys())][0])

    # otherwise return dofs with d <= dim
    if restriction_domain == "vertex":
        dim = 0
    elif restriction_domain == "edge":
        dim = 1
    elif restriction_domain == "face":
        dim = 2
    elif restriction_domain == "facet":
        dim = element.get_reference_element().get_spatial_dimension() - 1
    else:
        raise RuntimeError("Invalid restriction domain")

    # Entity keys of dimension <= dim, ordered by dimension and then by key.
    # list.sort is stable, so keys of equal dimension keep sorted_by_key
    # order; for tensor-product keys (a, b) that is a = 0, 1, ..., d.
    keys = [key for key, _ in sorted_by_key(entity_dofs)
            if _entity_dimension(key) <= dim]
    keys.sort(key=_entity_dimension)

    indices = []
    for key in keys:
        for _entity, dofs in sorted_by_key(entity_dofs[key]):
            indices += dofs
    return indices