1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
139 typename =
typename std::remove_reference<T>::type::pattern_tag>
146 constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
166 const int lanes = scalar_type.
lanes;
167 scalar_type.
lanes = 1;
170 switch (scalar_type.
code) {
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
216 template<u
int32_t bound>
244 template<u
int32_t bound>
246 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
249 op = ((
const Broadcast *)op)->value.get();
258 state.get_bound_const(i, val, type);
261 state.set_bound_const(i, value, e.type);
265 template<u
int32_t bound>
267 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.
u.
i64;
274 state.set_bound_const(i, value, i64_type);
310 template<u
int32_t bound>
312 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
315 op = ((
const Broadcast *)op)->value.get();
324 state.get_bound_const(i, val, type);
327 state.set_bound_const(i, value, e.type);
343 state.get_bound_const(i, val, ty);
363 template<u
int32_t bound>
365 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
368 op = ((
const Broadcast *)op)->value.get();
373 double value = ((
const FloatImm *)op)->value;
377 state.get_bound_const(i, val, type);
380 state.set_bound_const(i, value, e.type);
396 state.get_bound_const(i, val, ty);
417 template<u
int32_t bound>
419 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
422 op = ((
const Broadcast *)op)->value.get();
436 template<u
int32_t bound>
438 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
454 state.get_bound_const(i, val, ty);
475 template<u
int32_t bound>
478 return equal(*state.get_binding(i), e);
480 state.set_binding(i, e);
515 template<u
int32_t bound>
519 op = ((
const Broadcast *)op)->value.get();
527 return ((
const FloatImm *)op)->value == (
double)
v;
533 template<u
int32_t bound>
538 template<u
int32_t bound>
562 val.u.f64 = (double)
v;
578 typename =
typename std::decay<T>::type::pattern_tag>
589 static_assert(!std::is_same<typename std::decay<T>::type,
Expr>::value || std::is_lvalue_reference<T>::value,
590 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
601 typename =
typename std::decay<T>::type::pattern_tag,
603 typename =
typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
639template<
typename Op,
typename A,
typename B>
654 A::canonical && B::canonical && (!
commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
656 template<u
int32_t bound>
658 if (e.node_type != Op::_node_type) {
661 const Op &op = (
const Op &)e;
662 return (
a.template match<bound>(*op.a.get(), state) &&
663 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
666 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
668 return (std::is_same<Op, Op2>::value &&
669 a.template match<bound>(
unwrap(op.a), state) &&
673 constexpr static bool foldable = A::foldable && B::foldable;
678 if (std::is_same<A, IntLiteral>::value) {
679 b.make_folded_const(val_b, ty, state);
680 if ((std::is_same<Op, And>::value && val_b.
u.
u64 == 0) ||
681 (std::is_same<Op, Or>::value && val_b.
u.
u64 == 1)) {
687 a.make_folded_const(val_a, ty, state);
690 a.make_folded_const(val_a, ty, state);
691 if ((std::is_same<Op, And>::value && val_a.
u.
u64 == 0) ||
692 (std::is_same<Op, Or>::value && val_a.
u.
u64 == 1)) {
698 b.make_folded_const(val_b, ty, state);
703 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.
u.
i64, val_b.
u.
i64);
706 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.
u.
u64, val_b.
u.
u64);
710 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.
u.
f64, val_b.
u.
f64);
721 if (std::is_same<A, IntLiteral>::value) {
722 eb =
b.make(state, type_hint);
723 ea =
a.make(state, eb.
type());
725 ea =
a.make(state, type_hint);
726 eb =
b.make(state, ea.
type());
736 return Op::make(std::move(ea), std::move(eb));
750template<
typename Op,
typename A,
typename B>
762 (!
commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
766 template<u
int32_t bound>
768 if (e.node_type != Op::_node_type) {
771 const Op &op = (
const Op &)e;
772 return (
a.template match<bound>(*op.a.get(), state) &&
773 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
776 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
778 return (std::is_same<Op, Op2>::value &&
779 a.template match<bound>(
unwrap(op.a), state) &&
783 constexpr static bool foldable = A::foldable && B::foldable;
789 if (std::is_same<A, IntLiteral>::value) {
790 b.make_folded_const(val_b, ty, state);
792 a.make_folded_const(val_a, ty, state);
795 a.make_folded_const(val_a, ty, state);
797 b.make_folded_const(val_b, ty, state);
802 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
i64, val_b.
u.
i64);
805 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
u64, val_b.
u.
u64);
809 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
f64, val_b.
u.
f64);
823 if (std::is_same<A, IntLiteral>::value) {
824 eb =
b.make(state, {});
825 ea =
a.make(state, eb.
type());
827 ea =
a.make(state, {});
828 eb =
b.make(state, ea.
type());
838 return Op::make(std::move(ea), std::move(eb));
842template<
typename A,
typename B>
844 s <<
"(" << op.
a <<
" + " << op.
b <<
")";
848template<
typename A,
typename B>
850 s <<
"(" << op.
a <<
" - " << op.
b <<
")";
854template<
typename A,
typename B>
856 s <<
"(" << op.
a <<
" * " << op.
b <<
")";
860template<
typename A,
typename B>
862 s <<
"(" << op.
a <<
" / " << op.
b <<
")";
866template<
typename A,
typename B>
868 s <<
"(" << op.
a <<
" && " << op.
b <<
")";
872template<
typename A,
typename B>
874 s <<
"(" << op.
a <<
" || " << op.
b <<
")";
878template<
typename A,
typename B>
880 s <<
"min(" << op.
a <<
", " << op.
b <<
")";
884template<
typename A,
typename B>
886 s <<
"max(" << op.
a <<
", " << op.
b <<
")";
890template<
typename A,
typename B>
892 s <<
"(" << op.
a <<
" <= " << op.
b <<
")";
896template<
typename A,
typename B>
898 s <<
"(" << op.
a <<
" < " << op.
b <<
")";
902template<
typename A,
typename B>
904 s <<
"(" << op.
a <<
" >= " << op.
b <<
")";
908template<
typename A,
typename B>
910 s <<
"(" << op.
a <<
" > " << op.
b <<
")";
914template<
typename A,
typename B>
916 s <<
"(" << op.
a <<
" == " << op.
b <<
")";
920template<
typename A,
typename B>
922 s <<
"(" << op.
a <<
" != " << op.
b <<
")";
926template<
typename A,
typename B>
928 s <<
"(" << op.
a <<
" % " << op.
b <<
")";
932template<
typename A,
typename B>
934 assert_is_lvalue_if_expr<A>();
935 assert_is_lvalue_if_expr<B>();
939template<
typename A,
typename B>
941 assert_is_lvalue_if_expr<A>();
942 assert_is_lvalue_if_expr<B>();
949 int dead_bits = 64 - t.bits;
957 return (a + b) & (ones >> (64 - t.bits));
965template<
typename A,
typename B>
967 assert_is_lvalue_if_expr<A>();
968 assert_is_lvalue_if_expr<B>();
972template<
typename A,
typename B>
974 assert_is_lvalue_if_expr<A>();
975 assert_is_lvalue_if_expr<B>();
983 int dead_bits = 64 - t.bits;
990 return (a - b) & (ones >> (64 - t.bits));
998template<
typename A,
typename B>
1000 assert_is_lvalue_if_expr<A>();
1001 assert_is_lvalue_if_expr<B>();
1005template<
typename A,
typename B>
1007 assert_is_lvalue_if_expr<A>();
1008 assert_is_lvalue_if_expr<B>();
1015 int dead_bits = 64 - t.bits;
1023 return (a * b) & (ones >> (64 - t.bits));
1031template<
typename A,
typename B>
1033 assert_is_lvalue_if_expr<A>();
1034 assert_is_lvalue_if_expr<B>();
1038template<
typename A,
typename B>
1058template<
typename A,
typename B>
1060 assert_is_lvalue_if_expr<A>();
1061 assert_is_lvalue_if_expr<B>();
1065template<
typename A,
typename B>
1067 assert_is_lvalue_if_expr<A>();
1068 assert_is_lvalue_if_expr<B>();
1087template<
typename A,
typename B>
1089 assert_is_lvalue_if_expr<A>();
1090 assert_is_lvalue_if_expr<B>();
1109template<
typename A,
typename B>
1111 assert_is_lvalue_if_expr<A>();
1112 assert_is_lvalue_if_expr<B>();
1131template<
typename A,
typename B>
1136template<
typename A,
typename B>
1156template<
typename A,
typename B>
1161template<
typename A,
typename B>
1181template<
typename A,
typename B>
1186template<
typename A,
typename B>
1206template<
typename A,
typename B>
1211template<
typename A,
typename B>
1231template<
typename A,
typename B>
1236template<
typename A,
typename B>
1256template<
typename A,
typename B>
1261template<
typename A,
typename B>
1281template<
typename A,
typename B>
1286template<
typename A,
typename B>
1307template<
typename A,
typename B>
1312template<
typename A,
typename B>
1337template<
typename... Args>
1346template<
typename... Args>
1353 return a < b ? a : b;
1356template<
typename... Args>
1370 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1372 using T =
decltype(std::get<i>(
args));
1373 return (std::get<i>(
args).template match<bound>(*c.args[i].get(), state) &&
1374 match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1377 template<
int i, u
int32_t binds>
1382 template<u
int32_t bound>
1392 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1394 s << std::get<i>(
args);
1395 if (i + 1 <
sizeof...(Args)) {
1398 print_args<i + 1>(0, s);
1407 print_args<0>(0, s);
1412 Expr arg0 = std::get<0>(
args).make(state, type_hint);
1423 return absd(arg0, arg1);
1443 return arg0 << arg1;
1445 return arg0 >> arg1;
1471 std::get<0>(
args).make_folded_const(val, ty, state);
1476 std::get<1>(
args).make_folded_const(arg1, signed_ty, state);
1479 if (arg1.
u.
i64 < 0) {
1482 val.u.i64 >>= -arg1.
u.
i64;
1485 val.u.u64 >>= -arg1.
u.
i64;
1488 val.u.u64 <<= arg1.
u.
i64;
1491 if (arg1.
u.
i64 > 0) {
1494 val.u.i64 >>= arg1.
u.
i64;
1497 val.u.u64 >>= arg1.
u.
i64;
1500 val.u.u64 <<= -arg1.
u.
i64;
1513template<
typename... Args>
1521template<
typename... Args>
1526template<
typename A,
typename B>
1530template<
typename A,
typename B>
1534template<
typename A,
typename B>
1538template<
typename A,
typename B>
1542template<
typename A,
typename B>
1546template<
typename A,
typename B>
1550template<
typename A,
typename B>
1554template<
typename A,
typename B>
1558template<
typename A,
typename B>
1562template<
typename A,
typename B>
1566template<
typename A,
typename B>
1570template<
typename A,
typename B>
1574template<
typename A,
typename B>
1578template<
typename A,
typename B,
typename C>
1582template<
typename A,
typename B,
typename C>
1598 template<u
int32_t bound>
1603 const Not &op = (
const Not &)e;
1604 return (
a.template match<bound>(*op.
a.
get(), state));
1607 template<u
int32_t bound,
typename A2>
1609 return a.template match<bound>(
unwrap(op.a), state);
1619 template<
typename A1 = A>
1621 a.make_folded_const(val, ty, state);
1622 val.u.u64 = ~val.u.u64;
1629 assert_is_lvalue_if_expr<A>();
1635 assert_is_lvalue_if_expr<A>();
1641 s <<
"!(" << op.
a <<
")";
1645template<
typename C,
typename T,
typename F>
1657 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1659 template<u
int32_t bound>
1665 return (
c.template match<bound>(*op.
condition.
get(), state) &&
1666 t.template match<bound | bindings<C>::mask>(*op.
true_value.
get(), state) &&
1669 template<u
int32_t bound,
typename C2,
typename T2,
typename F2>
1671 return (
c.template match<bound>(
unwrap(instance.c), state) &&
1678 return Select::make(
c.make(state, {}),
t.make(state, type_hint),
f.make(state, type_hint));
1681 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1683 template<
typename C1 = C>
1687 c.make_folded_const(c_val, c_ty, state);
1688 if ((c_val.
u.
u64 & 1) == 1) {
1689 t.make_folded_const(val, ty, state);
1691 f.make_folded_const(val, ty, state);
1697template<
typename C,
typename T,
typename F>
1699 s <<
"select(" << op.
c <<
", " << op.
t <<
", " << op.
f <<
")";
1703template<
typename C,
typename T,
typename F>
1705 assert_is_lvalue_if_expr<C>();
1706 assert_is_lvalue_if_expr<T>();
1707 assert_is_lvalue_if_expr<F>();
1711template<
typename A,
typename B>
1722 constexpr static bool canonical = A::canonical && B::canonical;
1724 template<u
int32_t bound>
1728 if (
a.template match<bound>(*op.
value.
get(), state) &&
1729 lanes.template match<bound>(op.
lanes, state)) {
1736 template<u
int32_t bound,
typename A2,
typename B2>
1738 return (
a.template match<bound>(
unwrap(op.a), state) &&
1746 lanes.make_folded_const(lanes_val, ty, state);
1748 type_hint.
lanes /= l;
1749 Expr val =
a.make(state, type_hint);
1759 template<
typename A1 = A>
1763 lanes.make_folded_const(lanes_val, lanes_ty, state);
1765 a.make_folded_const(val, ty, state);
1770template<
typename A,
typename B>
1772 s <<
"broadcast(" << op.
a <<
", " << op.
lanes <<
")";
1776template<
typename A,
typename B>
1778 assert_is_lvalue_if_expr<A>();
1782template<
typename A,
typename B,
typename C>
1794 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1796 template<u
int32_t bound>
1802 if (
a.template match<bound>(*op.
base.
get(), state) &&
1803 b.template match<bound | bindings<A>::mask>(*op.
stride.
get(), state) &&
1811 template<u
int32_t bound,
typename A2,
typename B2,
typename C2>
1813 return (
a.template match<bound>(
unwrap(op.a), state) &&
1822 lanes.make_folded_const(lanes_val, ty, state);
1824 type_hint.
lanes /= l;
1826 eb =
b.make(state, type_hint);
1827 ea =
a.make(state, eb.type());
1834template<
typename A,
typename B,
typename C>
1836 s <<
"ramp(" << op.
a <<
", " << op.
b <<
", " << op.
lanes <<
")";
1840template<
typename A,
typename B,
typename C>
1842 assert_is_lvalue_if_expr<A>();
1843 assert_is_lvalue_if_expr<B>();
1844 assert_is_lvalue_if_expr<C>();
1848template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1860 template<u
int32_t bound>
1864 if (op.
op == reduce_op &&
1865 a.template match<bound>(*op.
value.
get(), state) &&
1866 lanes.template match<bound | bindings<A>::mask>(op.
type.
lanes(), state)) {
1873 template<u
int32_t bound,
typename A2,
typename B2, VectorReduce::Operator reduce_op_2>
1875 return (reduce_op == reduce_op_2 &&
1876 a.template match<bound>(
unwrap(op.a), state) &&
1884 lanes.make_folded_const(lanes_val, ty, state);
1885 int l = (int)lanes_val.
u.
i64;
1892template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1894 s <<
"vector_reduce(" << reduce_op <<
", " << op.
a <<
", " << op.
lanes <<
")";
1898template<
typename A,
typename B>
1900 assert_is_lvalue_if_expr<A>();
1904template<
typename A,
typename B>
1906 assert_is_lvalue_if_expr<A>();
1910template<
typename A,
typename B>
1912 assert_is_lvalue_if_expr<A>();
1916template<
typename A,
typename B>
1918 assert_is_lvalue_if_expr<A>();
1922template<
typename A,
typename B>
1924 assert_is_lvalue_if_expr<A>();
1940 template<u
int32_t bound>
1945 const Sub &op = (
const Sub &)e;
1946 return (
a.template match<bound>(*op.
b.
get(), state) &&
1950 template<u
int32_t bound,
typename A2>
1952 return a.template match<bound>(
unwrap(p.a), state);
1957 Expr ea =
a.make(state, type_hint);
1959 return Sub::make(std::move(z), std::move(ea));
1964 template<
typename A1 = A>
1966 a.make_folded_const(val, ty, state);
1967 int dead_bits = 64 - ty.bits;
1970 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1979 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1983 val.u.f64 = -val.u.f64;
2000 assert_is_lvalue_if_expr<A>();
2006 assert_is_lvalue_if_expr<A>();
2022 template<u
int32_t bound>
2028 return (e.type ==
t &&
2029 a.template match<bound>(*op.
value.
get(), state));
2031 template<u
int32_t bound,
typename A2>
2033 return t == op.t &&
a.template match<bound>(
unwrap(op.a), state);
2038 return cast(
t,
a.make(state, {}));
2046 s <<
"cast(" << op.
t <<
", " << op.
a <<
")";
2052 assert_is_lvalue_if_expr<A>();
2071 a.make_folded_const(c, ty, state);
2077 if (type_hint.bits) {
2081 c.
u.
f64 = (double)x;
2083 ty.
code = type_hint.code;
2084 ty.
bits = type_hint.bits;
2093 template<
typename A1 = A>
2095 a.make_folded_const(val, ty, state);
2101 assert_is_lvalue_if_expr<A>();
2107 s <<
"fold(" << op.
a <<
")";
2126 template<
typename A1 = A>
2128 a.make_folded_const(val, ty, state);
2138 assert_is_lvalue_if_expr<A>();
2144 s <<
"overflows(" << op.
a <<
")";
2158 template<u
int32_t bound>
2204 template<
typename A1 = A>
2206 Expr e =
a.make(state, {});
2220 assert_is_lvalue_if_expr<A>();
2226 assert_is_lvalue_if_expr<A>();
2233 s <<
"is_const(" << op.
a <<
")";
2235 s <<
"is_const(" << op.
a <<
", " << op.
v <<
")";
2240template<
typename A,
typename Prover>
2257 Expr condition =
a.make(state, {});
2258 condition =
prover->mutate(condition,
nullptr);
2266template<
typename A,
typename Prover>
2268 assert_is_lvalue_if_expr<A>();
2272template<
typename A,
typename Prover>
2274 s <<
"can_prove(" << op.
a <<
")";
2295 Type t =
a.make(state, {}).type();
2305 assert_is_lvalue_if_expr<A>();
2311 s <<
"is_float(" << op.
a <<
")";
2333 Type t =
a.make(state, {}).type();
2343 assert_is_lvalue_if_expr<A>();
2349 s <<
"is_int(" << op.
a;
2351 s <<
", " << op.
bits;
2375 Type t =
a.make(state, {}).type();
2385 assert_is_lvalue_if_expr<A>();
2391 s <<
"is_uint(" << op.
a;
2393 s <<
", " << op.
bits;
2416 Type t =
a.make(state, {}).type();
2426 assert_is_lvalue_if_expr<A>();
2432 s <<
"is_scalar(" << op.
a <<
")";
2453 a.make_folded_const(val, ty, state);
2456 val.
u.
u64 = (val.
u.
u64 == max_bits);
2467 assert_is_lvalue_if_expr<A>();
2473 s <<
"is_max_value(" << op.
a <<
")";
2494 a.make_folded_const(val, ty, state);
2497 val.
u.
u64 = (val.
u.
u64 == min_bits);
2510 assert_is_lvalue_if_expr<A>();
2516 s <<
"is_min_value(" << op.
a <<
")";
2521template<
typename Before,
2524 typename =
typename std::enable_if<std::decay<Before>::type::foldable &&
2525 std::decay<After>::type::foldable>::type>
2530 wildcard_type.lanes = output_type.lanes = 1;
2533 static std::set<uint32_t> tested;
2535 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2540 debug(0) <<
"validate('" << before <<
"', '" << after <<
"', '" << pred <<
"', " <<
Type(wildcard_type) <<
", " <<
Type(output_type) <<
")\n";
2545 static std::mt19937_64 rng(0);
2550 for (
int trials = 0; trials < 100; trials++) {
2554 int shift = (int)(rng() & (wildcard_type.bits - 1));
2556 for (
int i = 0; i <
max_wild; i++) {
2558 switch (wildcard_type.code) {
2578 double val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2580 val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2594 before.make_folded_const(val_before, type, state);
2596 after.make_folded_const(val_after, type, state);
2597 lanes |= type.
lanes;
2604 switch (output_type.code) {
2619 ok &= (error < 0.01 ||
2620 val_before.
u.
u64 == val_after.
u.
u64 ||
2621 std::isnan(val_before.
u.
f64));
2629 debug(0) <<
"Fails with values:\n";
2630 for (
int i = 0; i <
max_wild; i++) {
2635 for (
int i = 0; i <
max_wild; i++) {
2640 debug(0) << val_before.
u.
u64 <<
" " << val_after.
u.
u64 <<
"\n";
2646template<
typename Before,
2649 typename =
typename std::enable_if<!(std::decay<Before>::type::foldable &&
2650 std::decay<After>::type::foldable)>::type>
2661template<
typename Pattern,
2662 typename =
typename enable_if_pattern<Pattern>::type>
2666 p.make_folded_const(c, ty, state);
2674#define HALIDE_DEBUG_MATCHED_RULES 0
2675#define HALIDE_DEBUG_UNMATCHED_RULES 0
2681#define HALIDE_FUZZ_TEST_RULES 0
2683template<
typename Instance>
2696 template<
typename After>
2701 template<
typename Before,
2706 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2707 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2708 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2709#if HALIDE_FUZZ_TEST_RULES
2714#if HALIDE_DEBUG_MATCHED_RULES
2719#if HALIDE_DEBUG_UNMATCHED_RULES
2720 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2726 template<
typename Before,
2729 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2732#if HALIDE_DEBUG_MATCHED_RULES
2737#if HALIDE_DEBUG_UNMATCHED_RULES
2738 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2744 template<
typename Before,
2747 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2748#if HALIDE_FUZZ_TEST_RULES
2753#if HALIDE_DEBUG_MATCHED_RULES
2758#if HALIDE_DEBUG_UNMATCHED_RULES
2759 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2765 template<
typename Before,
2772 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2773 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2774 static_assert((Before::binds & Predicate::binds) == Predicate::binds,
"Rule predicate uses unbound values");
2775 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2776 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2778#if HALIDE_FUZZ_TEST_RULES
2784#if HALIDE_DEBUG_MATCHED_RULES
2785 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2789#if HALIDE_DEBUG_UNMATCHED_RULES
2790 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2796 template<
typename Before,
2801 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2802 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2807#if HALIDE_DEBUG_MATCHED_RULES
2808 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2812#if HALIDE_DEBUG_UNMATCHED_RULES
2813 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2819 template<
typename Before,
2824 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2825 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2826#if HALIDE_FUZZ_TEST_RULES
2832#if HALIDE_DEBUG_MATCHED_RULES
2833 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2837#if HALIDE_DEBUG_UNMATCHED_RULES
2838 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2862template<
typename Instance,
2863 typename =
typename enable_if_pattern<Instance>::type>
2865 return {
pattern_arg(instance), output_type, wildcard_type};
2868template<
typename Instance,
2869 typename =
typename enable_if_pattern<Instance>::type>
2871 return {
pattern_arg(instance), output_type, output_type};
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0) noexcept -> IsInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr bool and_reduce()
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
constexpr bool commutative(IRNodeType t)
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
constexpr uint32_t bitwise_or_reduce()
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
constexpr int const_min(int a, int b)
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
The sum of two expressions.
Logical and - are both expressions true.
A base class for expression nodes.
A vector with 'lanes' elements, in which every element is 'value'.
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
@ signed_integer_overflow
@ rounding_mul_shift_right
bool is_intrinsic() const
static const IRNodeType _node_type
The actual IR nodes begin here.
static const IRNodeType _node_type
The ratio of two expressions.
Is the first expression equal to the second.
Floating point constants.
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Is the first expression greater than the second.
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
static constexpr bool foldable
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr uint32_t binds
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
std::tuple< Args... > args
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool canonical
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr bool foldable
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr IRNodeType min_node_type
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
static constexpr bool foldable
static constexpr IRNodeType min_node_type
static constexpr bool canonical
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
To save stack space, the matcher objects are largely stateless and immutable.
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
static constexpr uint16_t special_values_mask
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
halide_type_t bound_const_type[max_wild]
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
HALIDE_ALWAYS_INLINE MatcherState() noexcept
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
halide_scalar_value_t bound_const[max_wild]
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
static constexpr uint16_t signed_integer_overflow
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr bool foldable
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE void build_replacement(After after)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
halide_type_t wildcard_type
halide_type_t output_type
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType max_node_type
const BaseExprNode & expr
static constexpr uint32_t binds
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
static constexpr uint32_t binds
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool foldable
static constexpr bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr IRNodeType max_node_type
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr bool canonical
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr bool foldable
static constexpr bool canonical
static constexpr bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
static constexpr IRNodeType max_node_type
static constexpr uint32_t binds
static constexpr bool foldable
static constexpr IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
static constexpr IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool canonical
static constexpr IRNodeType min_node_type
static constexpr uint32_t binds
static constexpr IRNodeType max_node_type
static constexpr bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
static constexpr bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
static constexpr uint32_t mask
IRNodeType node_type
Each IR node subclass has a unique identifier.
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Is the first expression less than the second.
The greater of two values.
The lesser of two values.
The product of two expressions.
Is the first expression not equal to the second.
Logical not - true if the expression is false.
Logical or - is at least one of the expressions true.
A linear ramp vector node.
static const IRNodeType _node_type
static Expr make(Expr base, Expr stride, int lanes)
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
The difference of two expressions.
static const IRNodeType _node_type
static Expr make(Expr a, Expr b)
Unsigned integer constants.
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
static const IRNodeType _node_type
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double)?
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.