Source code for ufl_legacy.exproperators

# -*- coding: utf-8 -*-
"""This module attaches special functions to Expr.
This way we avoid circular dependencies between e.g.
Sum and its superclass Expr."""

# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later
#
# Modified by Massimiliano Leoni, 2016.

from itertools import chain
import numbers

from ufl_legacy.log import error
from ufl_legacy.utils.stacks import StackDict
from ufl_legacy.core.expr import Expr
from ufl_legacy.constantvalue import Zero, as_ufl
from ufl_legacy.algebra import Sum, Product, Division, Power, Abs
from ufl_legacy.tensoralgebra import Transposed, Inner
from ufl_legacy.core.multiindex import MultiIndex, Index, FixedIndex, IndexBase, indices
from ufl_legacy.indexed import Indexed
from ufl_legacy.indexsum import IndexSum
from ufl_legacy.tensors import as_tensor, ComponentTensor
from ufl_legacy.restriction import PositiveRestricted, NegativeRestricted
from ufl_legacy.differentiation import Grad
from ufl_legacy.index_combination_utils import create_slice_indices, merge_overlapping_indices

from ufl_legacy.exprequals import expr_equals

# --- Boolean operators ---

from ufl_legacy.conditional import LE, GE, LT, GT


def _le(left, right):
    """UFL operator: A boolean expression (left <= right) for use with conditional."""
    condition = LE(left, right)
    return condition


def _ge(left, right):
    """UFL operator: A boolean expression (left >= right) for use with conditional."""
    condition = GE(left, right)
    return condition


def _lt(left, right):
    """UFL operator: A boolean expression (left < right) for use with conditional."""
    condition = LT(left, right)
    return condition


def _gt(left, right):
    """UFL operator: A boolean expression (left > right) for use with conditional."""
    condition = GT(left, right)
    return condition


# '==' needs to implement comparison of expression representations,
# because dict and set call __eq__ when expressions are used in
# hashmaps. The other comparison operators can be overloaded in the
# language.  It is possible that we can overload eq as well, but
# we'll need to fix some issues first and also check for a possible
# significant performance hit with compilation of complex
# forms. Replacing a==b with equiv(a,b) all over the code could be one
# way to reduce such a performance hit, but we cannot do anything
# about dict and set calling __eq__...
Expr.__eq__ = expr_equals


# != is used at least by tests, possibly in code as well, and must
# mean the opposite of ==, i.e. when evaluated as bool it must mean
# 'not equal representation'.
def _ne(self, other):
    return not self.__eq__(other)


# Attach the comparison operators to Expr: '!=' negates '==', while
# '<', '>', '<=' and '>=' build LT, GT, LE and GE condition objects.
Expr.__ne__ = _ne
Expr.__lt__ = _lt
Expr.__gt__ = _gt
Expr.__le__ = _le
Expr.__ge__ = _ge

# Python operators 'and'/'or' cannot be overloaded, and bitwise
# operators &/| don't have the right precedence levels
# Expr.__and__ = _and
# Expr.__or__ = _or


def _as_tensor(self, indices):
    """UFL operator: A^indices := as_tensor(A, indices).

    ``indices`` must be a tuple containing only Index objects; anything
    else is reported through ``error``.
    """
    is_index_tuple = isinstance(indices, tuple) and all(
        isinstance(idx, Index) for idx in indices)
    if not is_index_tuple:
        error("Expecting a tuple of Index objects to A^indices := as_tensor(A, indices).")
    return as_tensor(self, indices)


Expr.__xor__ = _as_tensor


# --- Helper functions for product handling ---

def _mult(a, b):
    """Build the expression for the product ``a * b`` of two UFL expressions.

    The operand ranks select the kind of product:

    - scalar * scalar: a plain ``Product``.
    - scalar * tensor (either order): the scalar is multiplied into the
      components of the tensor and the result is wrapped as a tensor again.
    - matrix * matrix and matrix * vector: a dot product in index notation,
      contracting the last axis of ``a`` with the first axis of ``b``.
      Repeated free indices are rejected in this case.

    Any other rank combination is reported through ``error``. If one operand
    of a non-scalar product is ``Zero``, a ``Zero`` with the resulting shape
    is returned directly. Free indices occurring in both operands are summed
    over.
    """
    # Merge the free indices of both operands; those found in both are
    # reported as repeated and get an implicit summation at the end
    free_indices, free_index_dims, repeated, _ = merge_overlapping_indices(
        a.ufl_free_indices, a.ufl_index_dimensions,
        b.ufl_free_indices, b.ufl_index_dimensions)

    shape_a, shape_b = a.ufl_shape, b.ufl_shape
    rank_a, rank_b = len(shape_a), len(shape_b)

    if rank_a == 0 and rank_b == 0:
        # Plain scalar product, nothing to wrap as tensor
        product = Product(a, b)
        tensor_indices = ()

    elif rank_a == 0 or rank_b == 0:
        # Scalar - tensor product, with the scalar placed first
        scalar, tensor = (b, a) if rank_b == 0 else (a, b)

        # Simplify early if either operand is zero
        if any(isinstance(operand, Zero) for operand in (scalar, tensor)):
            return Zero(shape_a or shape_b, free_indices, free_index_dims)

        # Repeated indices are allowed here, like in v[i]*M[i,:].
        # Multiply the scalar into the indexed (scalar valued) tensor
        tensor_indices = indices(len(tensor.ufl_shape))
        product = Product(scalar, tensor[tensor_indices])

    elif rank_a == 2 and rank_b in (1, 2):
        # Matrix-matrix (A*B, M*grad(u)) or matrix-vector (A*v) product
        if repeated:
            error("Not expecting repeated indices in non-scalar product.")

        # Simplify early if either operand is zero
        if isinstance(a, Zero) or isinstance(b, Zero):
            return Zero(shape_a[:-1] + shape_b[1:], free_indices, free_index_dims)

        # Express the dot product in index notation, summing over the
        # shared contraction index
        row_indices = indices(rank_a - 1)
        col_indices = indices(rank_b - 1)
        contracted = indices(1)

        product = a[row_indices + contracted] * b[contracted + col_indices]
        tensor_indices = row_indices + col_indices

    else:
        error("Invalid ranks {0} and {1} in product.".format(rank_a, rank_b))

    # TODO: I think applying as_tensor after index sums results in
    # cleaner expression graphs.
    # Turn the scalar valued product back into a tensor
    if tensor_indices:
        product = as_tensor(product, tensor_indices)

    # Implicit summation over every repeated index
    for count in repeated:
        product = IndexSum(product, MultiIndex((Index(count=count),)))

    return product


# --- Extend Expr with algebraic operators ---

# Operand types accepted by the binary operator overloads below; for
# any other type they return NotImplemented.
_valid_types = (Expr, numbers.Real, numbers.Integral, numbers.Complex)


def _mul(self, o):
    """UFL operator: ``self * o``.

    Returns NotImplemented when ``o`` is not one of the accepted operand
    types; otherwise ``o`` is converted with ``as_ufl`` and the product is
    built by ``_mult``.
    """
    if isinstance(o, _valid_types):
        return _mult(self, as_ufl(o))
    return NotImplemented


Expr.__mul__ = _mul


def _rmul(self, o):
    """UFL operator: ``o * self`` with ``o`` as the left operand.

    Returns NotImplemented when ``o`` is not one of the accepted operand
    types; otherwise ``o`` is converted with ``as_ufl`` and the product is
    built by ``_mult``.
    """
    if isinstance(o, _valid_types):
        return _mult(as_ufl(o), self)
    return NotImplemented


Expr.__rmul__ = _rmul


def _add(self, o):
    """UFL operator: ``self + o``, or NotImplemented for unsupported operand types."""
    if isinstance(o, _valid_types):
        return Sum(self, o)
    return NotImplemented


Expr.__add__ = _add


def _radd(self, o):
    """UFL operator: ``o + self`` with ``o`` as the left operand.

    Returns NotImplemented for unsupported operand types. Adding a number
    equal to zero returns ``self`` unchanged.
    """
    if not isinstance(o, _valid_types):
        return NotImplemented
    # Allow adding scalar int 0 as a no-op, even for shaped self,
    # needed for sum([a,b])
    adds_nothing = isinstance(o, numbers.Number) and o == 0
    return self if adds_nothing else Sum(o, self)


Expr.__radd__ = _radd


def _sub(self, o):
    """UFL operator: ``self - o``, expressed as the sum of ``self`` and ``-o``.

    Returns NotImplemented for unsupported operand types.
    """
    if isinstance(o, _valid_types):
        negated = -o
        return Sum(self, negated)
    return NotImplemented


Expr.__sub__ = _sub


def _rsub(self, o):
    """UFL operator: ``o - self``, expressed as the sum of ``o`` and ``-self``.

    Returns NotImplemented for unsupported operand types.
    """
    if isinstance(o, _valid_types):
        negated = -self
        return Sum(o, negated)
    return NotImplemented


Expr.__rsub__ = _rsub


def _div(self, o):
    """UFL operator: ``self / o``.

    Returns NotImplemented for unsupported operand types. When ``self`` has
    a non-empty shape, the division is applied to its indexed (scalar
    valued) form and the result is wrapped as a tensor with the same
    indices.
    """
    if not isinstance(o, _valid_types):
        return NotImplemented
    shape = self.ufl_shape
    if not shape:
        return Division(self, o)
    component = indices(len(shape))
    quotient = Division(self[component], o)
    return as_tensor(quotient, component)


Expr.__div__ = _div
Expr.__truediv__ = _div


def _rdiv(self, o):
    """UFL operator: ``o / self``, or NotImplemented for unsupported operand types."""
    if isinstance(o, _valid_types):
        return Division(o, self)
    return NotImplemented


Expr.__rdiv__ = _rdiv
Expr.__rtruediv__ = _rdiv


def _pow(self, o):
    """UFL operator: ``self ** o``.

    Returns NotImplemented for unsupported operand types. When ``o == 2``
    and ``self`` has a non-empty shape, the result is ``Inner(self, self)``;
    otherwise it is ``Power(self, o)``.
    """
    if not isinstance(o, _valid_types):
        return NotImplemented
    squares_tensor = o == 2 and self.ufl_shape
    return Inner(self, self) if squares_tensor else Power(self, o)


Expr.__pow__ = _pow


def _rpow(self, o):
    """UFL operator: ``o ** self``, or NotImplemented for unsupported operand types."""
    if isinstance(o, _valid_types):
        return Power(o, self)
    return NotImplemented


Expr.__rpow__ = _rpow


# TODO: Add Negated class for this? Might simplify reductions in Add.
def _neg(self):
    return -1 * self


# Unary minus on expressions: -a is built as the product -1 * a.
Expr.__neg__ = _neg


def _abs(self):
    """UFL operator: ``abs(self)``, represented by an ``Abs`` node."""
    magnitude = Abs(self)
    return magnitude


Expr.__abs__ = _abs


# --- Extend Expr with restriction operators a("+"), a("-") ---

def _restrict(self, side):
    """Restrict the expression to the "+" or "-" side.

    Returns ``PositiveRestricted(self)`` for "+" and
    ``NegativeRestricted(self)`` for "-"; any other side is reported
    through ``error``.
    """
    restrictions = (("+", PositiveRestricted), ("-", NegativeRestricted))
    for symbol, restriction in restrictions:
        if side == symbol:
            return restriction(self)
    error("Invalid side '%s' in restriction operator." % (side,))


def _eval(self, coord, mapping=None, component=()):
    """Evaluate the expression at the coordinate ``coord``.

    ``mapping`` provides values for other terminals and defaults to an
    empty dict; ``component`` is passed on to the recursive evaluation.
    """
    # Imported inside the function rather than at module level
    from ufl_legacy.algorithms import expand_derivatives

    # Derivatives are expanded before the recursive evaluation
    expanded = expand_derivatives(self)

    terminal_values = {} if mapping is None else mapping
    index_stack = StackDict()
    return expanded.evaluate(coord, terminal_values, component, index_stack)


def _call(self, arg, mapping=None, component=()):
    """UFL operator: restrict or evaluate the expression, depending on ``arg``.

    If ``arg`` is the string "+" or "-", the corresponding restriction of
    the expression is returned; passing a ``mapping`` in that case is
    reported through ``error``. Any other ``arg`` is treated as a coordinate
    and the expression is evaluated there, using ``mapping`` and
    ``component`` as in ``_eval``.
    """
    # Only a string can name a restriction side. Checking the type first
    # avoids comparing coordinate objects against "+" and "-": array-like
    # coordinates (e.g. NumPy arrays) can implement == elementwise, which
    # makes a plain membership test fail instead of selecting evaluation.
    if isinstance(arg, str) and arg in ("+", "-"):
        if mapping is not None:
            error("Not expecting a mapping when taking restriction.")
        return _restrict(self, arg)
    return _eval(self, arg, mapping, component)


Expr.__call__ = _call


# --- Extend Expr with the transpose operation A.T ---

def _transpose(self):
    """Transpose a rank-2 tensor expression.

    For more general transpose operations of higher order tensor
    expressions, use indexing and Tensor.
    """
    transposed = Transposed(self)
    return transposed


Expr.T = property(_transpose)


# --- Extend Expr with indexing operator a[i] ---

def analyse_key(ii, rank):
    """Takes something the user might input as an index tuple
    inside [], which could include complete slices (:) and
    ellipsis (...), and returns tuples of actual UFL index objects.

    The return value is a tuple (indices, axis_indices),
    each being a tuple of IndexBase instances.

    The return value 'indices' corresponds to all
    input objects of these types:
    - Index
    - FixedIndex
    - int => Wrapped in FixedIndex

    The return value 'axis_indices' corresponds to all
    input objects of these types:
    - Complete slice (:) => Replaced by a single new index
    - Ellipsis (...) => Replaced by multiple new indices

    Duplicate ellipses, partial slices and objects that cannot be
    converted to an index are reported through ``error``.
    """
    # Wrap in tuple
    if not isinstance(ii, (tuple, MultiIndex)):
        ii = (ii,)
    else:
        # Flatten nested tuples, happens with f[...,ii] where ii is a
        # tuple of indices
        jj = []
        for j in ii:
            if isinstance(j, (tuple, MultiIndex)):
                jj.extend(j)
            else:
                jj.append(j)
        ii = tuple(jj)

    # Convert all indices to Index or FixedIndex objects.  If there is
    # an ellipsis, split the indices into before and after.
    axis_indices = set()
    pre = []
    post = []
    indexlist = pre
    for i in ii:
        # Ellipsis is a singleton, so test identity; this avoids calling
        # the __eq__ of arbitrary index objects
        if i is Ellipsis:
            # Switch from pre to post list when an ellipsis is
            # encountered
            if indexlist is not pre:
                error("Found duplicate ellipsis.")
            indexlist = post
        else:
            # Convert index to a proper type
            if isinstance(i, numbers.Integral):
                idx = FixedIndex(i)
            elif isinstance(i, IndexBase):
                idx = i
            elif isinstance(i, slice):
                if i == slice(None):
                    idx = Index()
                    axis_indices.add(idx)
                else:
                    # TODO: Use ListTensor to support partial slices?
                    error("Partial slices not implemented, only complete slices like [:]")
            else:
                error("Can't convert this object to index: %s" % (i,))

            # Store index in pre or post list
            indexlist.append(idx)

    # Handle ellipsis as a number of complete slices, that is create a
    # number of new axis indices
    num_axis = rank - len(pre) - len(post)
    if indexlist is post:
        ellipsis_indices = indices(num_axis)
        axis_indices.update(ellipsis_indices)
    else:
        ellipsis_indices = ()

    # Construct final tuples to return; axis indices keep the order in
    # which they appear among all indices
    all_indices = tuple(chain(pre, ellipsis_indices, post))
    axis_indices = tuple(i for i in all_indices if i in axis_indices)

    return all_indices, axis_indices
def _getitem(self, component):
    """UFL operator: index the expression, ``self[component]``.

    ``component`` may be a single index or a tuple of indices, including
    slices (:) and Ellipsis (...). Slices and ellipsis make the result
    tensor valued again, and repeated indices are summed over. A wrong
    number of indices is reported through ``error``.
    """
    # Treat component consistently as tuple below
    key = component if isinstance(component, tuple) else (component,)

    operand_shape = self.ufl_shape

    # Analyse slices (:) and Ellipsis (...)
    all_indices, slice_indices, repeated_indices = create_slice_indices(
        key, operand_shape, self.ufl_free_indices)

    # The number of indices must match the rank of this expression
    num_indices = len(all_indices)
    if len(operand_shape) != num_indices:
        error("Invalid number of indices {0} for expression of rank {1}.".format(num_indices, len(operand_shape)))

    # Special case for simplifying foo[...] => foo, foo[:] => foo or
    # similar
    if len(slice_indices) == num_indices:
        return self

    # Special case for simplifying as_tensor(ai,(i,))[i] => ai
    if isinstance(self, ComponentTensor) and all_indices == self.indices().indices():
        return self.ufl_operands[0]

    # Apply all indices to index self, yielding a scalar valued
    # expression
    result = Indexed(self, MultiIndex(all_indices))

    # TODO: I think applying as_tensor after index sums results in
    # cleaner expression graphs.

    # If the Ellipsis or any slices were found, wrap as tensor valued
    # with the slice indices created at the top here
    if slice_indices:
        result = as_tensor(result, slice_indices)

    # If any repeated indices were found, apply implicit summation
    # over those
    for repeated in repeated_indices:
        result = IndexSum(result, MultiIndex((repeated,)))

    # Check for zero (last so we can get indices etc from the result,
    # could possibly be done faster by checking early instead)
    if isinstance(self, Zero):
        result = Zero(result.ufl_shape,
                      result.ufl_free_indices,
                      result.ufl_index_dimensions)

    return result


Expr.__getitem__ = _getitem


# --- Extend Expr with spatial differentiation operator a.dx(i) ---

def _dx(self, *ii):
    "Return the partial derivative with respect to spatial variable number *ii*."
    # Unwrap ii to allow .dx(i,j) and .dx((i,j))
    if len(ii) == 1 and isinstance(ii[0], tuple):
        ii = ii[0]

    # Apply one gradient per requested derivative
    derivative = self
    for _ in ii:
        derivative = Grad(derivative)

    # Take all components, applying repeated index sums in the [] operation
    return derivative.__getitem__((Ellipsis,) + ii)


Expr.dx = _dx