Basix
Loading...
Searching...
No Matches
maps.h
1// Copyright (c) 2021-2022 Matthew Scroggs and Garth N. Wells
2// FEniCS Project
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
#include "mdspan.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
11
13namespace basix::maps
14{
15
namespace impl
{
/// @brief Scalar type associated with `T` (primary template).
///
/// Selected when `T` has no nested `value_type`; in that case `T` is
/// reported as its own scalar type.
template <typename T, typename = void>
struct scalar_value_type
{
  /// The scalar type, here `T` itself
  using value_type = T;
};

/// @brief Scalar type associated with `T` (nested `value_type` case).
///
/// Selected through `std::void_t` whenever `T::value_type` names a
/// type; that nested type is reported as the scalar type.
template <typename T>
struct scalar_value_type<T, std::void_t<typename T::value_type>>
{
  /// The scalar type, here `T::value_type`
  using value_type = typename T::value_type;
};

/// Shorthand for `scalar_value_type<T>::value_type`
template <typename T>
using scalar_value_type_t = typename scalar_value_type<T>::value_type;
} // namespace impl
36
/// @brief Identifiers for the supported map types.
///
/// The numeric values are fixed explicitly so that they remain stable
/// if enumerators are reordered.
enum class type
{
  /// Identity map
  identity = 0,
  /// L2 Piola map (see l2_piola)
  L2Piola = 1,
  /// Covariant Piola map (see covariant_piola)
  covariantPiola = 2,
  /// Contravariant Piola map (see contravariant_piola)
  contravariantPiola = 3,
  /// Double covariant Piola map (see double_covariant_piola)
  doubleCovariantPiola = 4,
  /// Double contravariant Piola map (see double_contravariant_piola)
  doubleContravariantPiola = 5
};
47
/// @brief L2 Piola map: `r(i, j) = U(i, j) / detJ`.
///
/// @param[out] r Output view; must have the same extents as `U`
/// (checked with `assert`). Needs `extent()`, `operator()(i, j)` and a
/// nested `value_type`.
/// @param[in] U Input view holding the data to be mapped.
/// @param[in] detJ Value every entry of `U` is divided by.
///
/// The Jacobian (third argument) and its inverse (fifth argument) are
/// accepted for signature compatibility with the other maps but are
/// not used.
template <typename O, typename P, typename Q, typename R>
void l2_piola(O&& r, const P& U, const Q& /*J*/, double detJ, const R& /*K*/)
{
  using T = typename std::decay_t<O>::value_type;

  assert(U.extent(0) == r.extent(0));
  assert(U.extent(1) == r.extent(1));

  // Convert detJ once to the output value type. Dividing directly by
  // the double would not compile when the data is std::complex<float>
  // (there is no mixed complex<float> / double operator).
  const T scale = static_cast<T>(detJ);
  for (std::size_t i = 0; i < U.extent(0); ++i)
    for (std::size_t j = 0; j < U.extent(1); ++j)
      r(i, j) = U(i, j) / scale;
}
58
60template <typename O, typename P, typename Q, typename R>
61void covariant_piola(O&& r, const P& U, const Q& /*J*/, double /*detJ*/,
62 const R& K)
63{
64 using T = typename std::decay_t<O>::value_type;
65 using Z = typename impl::scalar_value_type_t<T>;
66 for (std::size_t p = 0; p < U.extent(0); ++p)
67 {
68 // r_p = K^T U_p, where p indicates the p-th row
69 for (std::size_t i = 0; i < r.extent(1); ++i)
70 {
71 T acc = 0;
72 for (std::size_t k = 0; k < K.extent(0); ++k)
73 acc += static_cast<Z>(K(k, i)) * U(p, k);
74 r(p, i) = acc;
75 }
76 }
77}
78
80template <typename O, typename P, typename Q, typename R>
81void contravariant_piola(O&& r, const P& U, const Q& J, double detJ,
82 const R& /*K*/)
83{
84 using T = typename std::decay_t<O>::value_type;
85 using Z = typename impl::scalar_value_type_t<T>;
86 for (std::size_t p = 0; p < U.extent(0); ++p)
87 {
88 for (std::size_t i = 0; i < r.extent(1); ++i)
89 {
90 T acc = 0;
91 for (std::size_t k = 0; k < J.extent(1); ++k)
92 acc += static_cast<Z>(J(i, k)) * U(p, k);
93 r(p, i) = acc;
94 }
95 }
96
97 std::transform(r.data_handle(), r.data_handle() + r.size(), r.data_handle(),
98 [detJ](auto ri) { return ri / static_cast<Z>(detJ); });
99}
100
102template <typename O, typename P, typename Q, typename R>
103void double_covariant_piola(O&& r, const P& U, const Q& J, double /*detJ*/,
104 const R& K)
105{
106 namespace stdex
107 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
108 using T = typename std::decay_t<O>::value_type;
109 using Z = typename impl::scalar_value_type_t<T>;
110 for (std::size_t p = 0; p < U.extent(0); ++p)
111 {
112 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
113 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
114 _U(U.data_handle() + p * U.extent(1), J.extent(1), J.extent(1));
115 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
116 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
117 _r(r.data_handle() + p * r.extent(1), K.extent(1), K.extent(1));
118 // _r = K^T _U K
119 for (std::size_t i = 0; i < _r.extent(0); ++i)
120 {
121 for (std::size_t j = 0; j < _r.extent(1); ++j)
122 {
123 T acc = 0;
124 for (std::size_t k = 0; k < K.extent(0); ++k)
125 for (std::size_t l = 0; l < _U.extent(1); ++l)
126 acc += static_cast<Z>(K(k, i)) * _U(k, l) * static_cast<Z>(K(l, j));
127 _r(i, j) = acc;
128 }
129 }
130 }
131}
132
134template <typename O, typename P, typename Q, typename R>
135void double_contravariant_piola(O&& r, const P& U, const Q& J, double detJ,
136 const R& /*K*/)
137{
138 namespace stdex
139 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
140 using T = typename std::decay_t<O>::value_type;
141 using Z = typename impl::scalar_value_type_t<T>;
142 for (std::size_t p = 0; p < U.extent(0); ++p)
143 {
144 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
145 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
146 _U(U.data_handle() + p * U.extent(1), J.extent(1), J.extent(1));
147 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
148 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
149 _r(r.data_handle() + p * r.extent(1), J.extent(0), J.extent(0));
150
151 // _r = J U J^T
152 for (std::size_t i = 0; i < _r.extent(0); ++i)
153 {
154 for (std::size_t j = 0; j < _r.extent(1); ++j)
155 {
156 T acc = 0;
157 for (std::size_t k = 0; k < J.extent(1); ++k)
158 for (std::size_t l = 0; l < _U.extent(1); ++l)
159 acc += static_cast<Z>(J(i, k)) * _U(k, l) * static_cast<Z>(J(j, l));
160 _r(i, j) = acc;
161 }
162 }
163 }
164
165 std::transform(r.data_handle(), r.data_handle() + r.size(), r.data_handle(),
166 [detJ](auto ri) { return ri / static_cast<Z>(detJ * detJ); });
167}
168
169} // namespace basix::maps
A finite element.
Definition finite-element.h:139
Information about finite element maps.
Definition maps.h:14
void l2_piola(O &&r, const P &U, const Q &, double detJ, const R &)
L2 Piola map.
Definition maps.h:50
void covariant_piola(O &&r, const P &U, const Q &, double, const R &K)
Covariant Piola map.
Definition maps.h:61
void contravariant_piola(O &&r, const P &U, const Q &J, double detJ, const R &)
Contravariant Piola map.
Definition maps.h:81
void double_contravariant_piola(O &&r, const P &U, const Q &J, double detJ, const R &)
Double contravariant Piola map.
Definition maps.h:135
void double_covariant_piola(O &&r, const P &U, const Q &J, double, const R &K)
Double covariant Piola map.
Definition maps.h:103
type
Map type.
Definition maps.h:39