Last active: April 11, 2019 01:34
-
-
Save yzhliu/0d2cdfc9fa92127b81a1298d5bec55a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h | |
index fbf8fe7e..1c397f40 100644 | |
--- a/src/arithmetic/const_fold.h | |
+++ b/src/arithmetic/const_fold.h | |
@@ -101,33 +101,28 @@ inline bool IsIndexType(const Type& type) { | |
// specialization of constant folders. | |
template<> | |
inline Expr TryConstFold<ir::Add>(Expr a, Expr b) { | |
- TVM_ARITH_CONST_PROPAGATION({ | |
+ TVM_INDEX_CONST_PROPAGATION({ | |
const Type& rtype = a.type(); | |
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); | |
if (pa && pa->value == 0) return b; | |
if (pb && pb->value == 0) return a; | |
- if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value); | |
- if (fa && fa->value == 0) return b; | |
- if (fb && fb->value == 0) return a; | |
}); | |
return Expr(); | |
} | |
template<> | |
inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) { | |
- TVM_ARITH_CONST_PROPAGATION({ | |
+ TVM_INDEX_CONST_PROPAGATION({ | |
const Type& rtype = a.type(); | |
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); | |
if (pb && pb->value == 0) return a; | |
- if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); | |
- if (fb && fb->value == 0) return a; | |
}); | |
return Expr(); | |
} | |
template<> | |
inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) { | |
- TVM_ARITH_CONST_PROPAGATION({ | |
+ TVM_INDEX_CONST_PROPAGATION({ | |
const Type& rtype = a.type(); | |
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); | |
if (pa) { | |
@@ -138,22 +133,13 @@ inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) { | |
if (pb->value == 1) return a; | |
if (pb->value == 0) return b; | |
} | |
- if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value); | |
- if (fa) { | |
- if (fa->value == 1) return b; | |
- if (fa->value == 0) return a; | |
- } | |
- if (fb) { | |
- if (fb->value == 1) return a; | |
- if (fb->value == 0) return b; | |
- } | |
}); | |
return Expr(); | |
} | |
template<> | |
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) { | |
- TVM_ARITH_CONST_PROPAGATION({ | |
+ TVM_INDEX_CONST_PROPAGATION({ | |
const Type& rtype = a.type(); | |
// due to division and mod can have different modes | |
// only constant fold positive number where rule is fixed. | |
@@ -167,14 +153,6 @@ inline Expr TryConstFold<ir::Div>(Expr a, Expr b) { | |
if (pb->value == 1) return a; | |
CHECK_NE(pb->value, 0) << "Divide by zero"; | |
} | |
- if (fa && fb && fb->value != 0) { | |
- return FloatImm::make(rtype, fa->value / fb->value); | |
- } | |
- if (fa && fa->value == 0) return a; | |
- if (fb) { | |
- if (fb->value == 1) return a; | |
- CHECK_NE(fb->value, 0) << "Divide by zero"; | |
- } | |
}); | |
return Expr(); | |
} | |
@@ -201,20 +179,18 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) { | |
template<> | |
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) { | |
- TVM_ARITH_CONST_PROPAGATION({ | |
+ TVM_INDEX_CONST_PROPAGATION({ | |
const Type& rtype = a.type(); | |
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); | |
- if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); | |
}); | |
return Expr(); | |
} | |
template<> | |
inline Expr TryConstFold<ir::Max>(Expr a, Expr b) { | |
- TVM_ARITH_CONST_PROPAGATION({ | |
+ TVM_INDEX_CONST_PROPAGATION({ | |
const Type& rtype = a.type(); | |
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); | |
- if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); | |
}); | |
return Expr(); | |
} | |
diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc | |
index 4504ee23..753ad6a8 100644 | |
--- a/src/lang/expr_operator.cc | |
+++ b/src/lang/expr_operator.cc | |
@@ -106,14 +106,11 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { | |
Expr cast(const Type& t, Expr value) { | |
using ir::IntImm; | |
- using ir::FloatImm; | |
if (value.type() == t) return value; | |
// const fold IntImm as they are used in index computations | |
if (t.lanes() == 1) { | |
if (const IntImm* op = value.as<IntImm>()) { | |
return make_const(t, op->value); | |
- } else if (const FloatImm* op = value.as<FloatImm>()) { | |
- return make_const(t, op->value); | |
} | |
return ir::Cast::make(t, value); | |
} else { | |
@@ -123,8 +120,6 @@ Expr cast(const Type& t, Expr value) { | |
if (value.type() != vtype) { | |
if (const IntImm* op = value.as<IntImm>()) { | |
value = make_const(vtype, op->value); | |
- } else if (const FloatImm* op = value.as<FloatImm>()) { | |
- value = make_const(vtype, op->value); | |
} else { | |
value = ir::Cast::make(vtype, value); | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.