1 #ifndef _H_TIRAMISU_EXPR_ 2 #define _H_TIRAMISU_EXPR_ 6 #include <isl/union_map.h> 7 #include <isl/union_set.h> 8 #include <isl/ast_build.h> 9 #include <isl/schedule.h> 10 #include <isl/schedule_node.h> 11 #include <isl/space.h> 14 #include <unordered_map> 18 #include <type_traits> 21 #include <tiramisu/debug.h> 53 static bool auto_data_mapping;
67 static function *implicit_fct;
82 return global::implicit_fct;
96 global::implicit_fct = fct;
108 global::auto_data_mapping = v;
121 return global::auto_data_mapping;
126 global::loop_iterator_type =
p_int32;
131 global::loop_iterator_type = t;
136 return global::loop_iterator_type;
167 std::vector<tiramisu::expr> op;
192 std::vector<tiramisu::expr> access_vector;
203 std::vector<tiramisu::expr> argument_vector;
233 this->defined =
false;
250 this->defined =
true;
252 this->op.push_back(expr0);
268 this->defined =
true;
270 this->op.push_back(expr0);
282 this->defined =
true;
294 tiramisu::str_dump(
"Binary operation between two expressions of different types:\n");
296 tiramisu::str_dump(
" and ");
298 tiramisu::str_dump(
"\n");
299 ERROR(
"\nThe two expressions should be of the same type. Use casting to elevate the type of one expression to the other.\n",
true);
305 this->defined =
true;
307 this->op.push_back(expr0);
308 this->op.push_back(expr1);
317 "expr1 and expr2 should be of the same type.");
322 this->defined =
true;
324 this->op.push_back(expr0);
325 this->op.push_back(expr1);
326 this->op.push_back(expr2);
333 std::vector<tiramisu::expr> vec,
338 "The operator is not an access or a call operator.");
340 assert(vec.size() > 0);
341 assert(name.size() > 0);
346 this->defined =
true;
351 this->set_access(vec);
355 this->set_arguments(vec);
359 ERROR(
"Type of operator is not o_access, o_call, o_address_of, o_buffer, or o_lin_index.",
true);
372 this->defined =
true;
375 this->uint8_value = val;
385 this->defined =
true;
388 this->int8_value = val;
396 this->defined =
true;
401 this->uint16_value = val;
409 this->defined =
true;
414 this->int16_value = val;
424 this->defined =
true;
427 this->uint32_value = val;
437 this->defined =
true;
440 this->int32_value = val;
450 this->defined =
true;
453 this->uint64_value = val;
463 this->defined =
true;
466 this->int64_value = val;
476 this->defined =
true;
479 this->float32_value = val;
494 this->defined =
true;
497 this->float64_value = val;
573 return float32_value;
581 return float64_value;
593 result = this->get_uint8_value();
597 result = this->get_int8_value();
601 result = this->get_uint16_value();
605 result = this->get_int16_value();
609 result = this->get_uint32_value();
613 result = this->get_int32_value();
617 result = this->get_uint64_value();
621 result = this->get_int64_value();
625 result = this->get_float32_value();
629 result = this->get_float64_value();
633 ERROR(
"Calling get_int_val() on a non integer expression.",
true);
647 result = this->get_float32_value();
651 result = this->get_float64_value();
655 ERROR(
"Calling get_double_val() on a non double expression.",
true);
668 assert((i < (
int)this->op.size()) &&
"Operand index is out of bounds.");
680 return this->op.size();
733 const std::string &replace_with)
735 if (this->name == to_replace) {
736 this->name = replace_with;
739 for (
int i = 0; i < this->op.size(); i++) {
769 return access_vector;
780 return argument_vector;
791 return access_vector.size();
816 uint16_t uint16_value;
818 uint32_t uint32_value;
820 uint64_t uint64_value;
823 double float64_value;
827 std::vector<tiramisu::expr> access_vector;
829 std::vector<tiramisu::expr> argument_vector;
831 if ((this->_operator != e._operator) ||
832 (this->op.size() != e.op.size()) ||
833 (this->access_vector.size() != e.access_vector.size()) ||
834 (this->argument_vector.size() != e.argument_vector.size()) ||
835 (this->defined != e.defined) ||
836 (this->name != e.
name) ||
837 (this->dtype != e.
dtype) ||
838 (this->etype != e.
etype))
844 for (
int i = 0; i < this->access_vector.size(); i++)
845 equal = equal && this->access_vector[i].is_equal(e.access_vector[i]);
847 for (
int i = 0; i < this->op.size(); i++)
848 equal = equal && this->op[i].is_equal(e.op[i]);
850 for (
int i = 0; i < this->argument_vector.size(); i++)
851 equal = equal && this->argument_vector[i].is_equal(e.argument_vector[i]);
992 access_vector = vector;
1001 assert((i < (
int)this->access_vector.size()) &&
"index is out of bounds.");
1002 access_vector[i] = acc;
1012 argument_vector = vector;
1023 if (this->get_expr_type() !=
e_none)
1025 if (exhaustive ==
true)
1027 if (ENABLE_DEBUG && (this->is_defined()))
1029 std::cout <<
"Expression:" << std::endl;
1031 switch (this->etype)
1035 std::cout <<
"Expression operator type:" <<
str_tiramisu_type_op(this->_operator) << std::endl;
1036 if (this->get_n_arg() > 0)
1038 std::cout <<
"Number of operands:" << this->get_n_arg() << std::endl;
1039 std::cout <<
"Dumping the operands:" << std::endl;
1040 for (
int i = 0; i < this->get_n_arg(); i++)
1042 std::cout <<
"Operand " << std::to_string(i) <<
"." << std::endl;
1043 this->op[i].dump(exhaustive);
1048 std::cout <<
"Access to " + this->get_name() +
". Access expressions:" << std::endl;
1049 for (
const auto &e : this->get_access())
1055 std::cout <<
"Address to " + this->get_name() +
". Access expressions:" << std::endl;
1056 for (
const auto &e : this->get_access()) {
1061 std::cout <<
"Linear address to " + this->get_name() +
". Access expressions:" 1063 for (
const auto &e : this->get_access()) {
1069 std::cout <<
"call to " + this->get_name() +
". Argument expressions:" << std::endl;
1070 for (
const auto &e : this->get_arguments())
1077 std::cout <<
"Address of the following access : " << std::endl;
1078 this->get_operand(0).dump(
true);
1082 std::cout <<
"allocate(" << this->get_name() <<
")" << std::endl;
1086 std::cout <<
"free(" << this->get_name() <<
")" << std::endl;
1096 std::cout <<
"Value:" << this->get_uint8_value() << std::endl;
1100 std::cout <<
"Value:" << this->get_int8_value() << std::endl;
1104 std::cout <<
"Value:" << this->get_uint16_value() << std::endl;
1108 std::cout <<
"Value:" << this->get_int16_value() << std::endl;
1112 std::cout <<
"Value:" << this->get_uint32_value() << std::endl;
1116 std::cout <<
"Value:" << this->get_int32_value() << std::endl;
1120 std::cout <<
"Value:" << this->get_uint64_value() << std::endl;
1124 std::cout <<
"Value:" << this->get_int64_value() << std::endl;
1128 std::cout <<
"Value:" << this->get_float32_value() << std::endl;
1132 std::cout <<
"Value:" << this->get_float64_value() << std::endl;
1138 std::cout <<
"Var name:" << this->get_name() << std::endl;
1143 std::cout <<
"Sync object" << std::endl;
1146 ERROR(
"Expression type not supported.",
true);
1152 std::cout << this->to_str();
1170 if (this->get_name() ==
"_unbounded")
1181 if (this->get_expr_type() !=
e_none)
1183 switch (this->etype)
1187 switch (this->get_op_type())
1200 this->get_operand(0).simplify();
1201 this->get_operand(1).simplify();
1204 return expr(this->get_operand(0).get_int_val() + this->get_operand(1).get_int_val());
1206 this->get_operand(0).simplify();
1207 this->get_operand(1).simplify();
1210 return expr(this->get_operand(0).get_int_val() - this->get_operand(1).get_int_val());
1212 this->get_operand(0).simplify();
1213 this->get_operand(1).simplify();
1216 return expr(this->get_operand(0).get_int_val() * this->get_operand(1).get_int_val());
1298 ERROR(
"Simplifying an unsupported tiramisu expression.", 1);
1311 ERROR(
"Expression type not supported.",
true);
1320 std::string str = std::string(
"");
1322 if (this->get_expr_type() !=
e_none)
1324 switch (this->etype)
1328 switch (this->get_op_type())
1332 this->get_operand(0).dump(
false);
1334 str += this->get_operand(1).to_str();
1338 str +=
"(" + this->get_operand(0).to_str();
1339 str +=
" || " + this->get_operand(1).to_str();
1343 str +=
"max(" + this->get_operand(0).to_str();
1344 str +=
", " + this->get_operand(1).to_str();
1348 str +=
"min(" + this->get_operand(0).to_str();
1349 str +=
", " + this->get_operand(1).to_str();
1353 str +=
"(-" + this->get_operand(0).to_str();
1357 str +=
"(" + this->get_operand(0).to_str();
1358 str +=
" + " + this->get_operand(1).to_str();
1362 str +=
"(" + this->get_operand(0).to_str();
1363 str +=
" - " + this->get_operand(1).to_str();
1367 str +=
"(" + this->get_operand(0).to_str();
1368 str +=
" * " + this->get_operand(1).to_str();
1372 str +=
"(" + this->get_operand(0).to_str();
1373 str +=
" / " + this->get_operand(1).to_str();
1377 str +=
"(" + this->get_operand(0).to_str();
1378 str +=
" % " + this->get_operand(1).to_str();
1382 str +=
"memcpy(" + this->get_operand(0).to_str();
1383 str +=
", " + this->get_operand(1).to_str();
1387 str +=
"select(" + this->get_operand(0).to_str();
1388 str +=
", " + this->get_operand(1).to_str();
1389 str +=
", " + this->get_operand(2).to_str();
1393 str +=
"if(" + this->get_operand(0).to_str();
1394 str +=
"):(" + this->get_operand(1).to_str();
1398 str +=
"lerp(" + this->get_operand(0).to_str();
1399 str +=
", " + this->get_operand(1).to_str();
1400 str +=
", " + this->get_operand(2).to_str();
1404 str +=
"(" + this->get_operand(0).to_str();
1405 str +=
" <= " + this->get_operand(1).to_str();
1409 str +=
"(" + this->get_operand(0).to_str();
1410 str +=
" < " + this->get_operand(1).to_str();
1414 str +=
"(" + this->get_operand(0).to_str();
1415 str +=
" >= " + this->get_operand(1).to_str();
1419 str +=
"(" + this->get_operand(0).to_str();
1420 str +=
" > " + this->get_operand(1).to_str();
1424 str +=
"(!" + this->get_operand(0).to_str();
1428 str +=
"(" + this->get_operand(0).to_str();
1429 str +=
" == " + this->get_operand(1).to_str();
1433 str +=
"(" + this->get_operand(0).to_str();
1434 str +=
" != " + this->get_operand(1).to_str();
1438 str +=
"(" + this->get_operand(0).to_str();
1439 str +=
" >> " + this->get_operand(1).to_str();
1443 str +=
"(" + this->get_operand(0).to_str();
1444 str +=
" << " + this->get_operand(1).to_str();
1448 str +=
"floor(" + this->get_operand(0).to_str();
1452 str +=
"sin(" + this->get_operand(0).to_str();
1456 str +=
"cos(" + this->get_operand(0).to_str();
1460 str +=
"tan(" + this->get_operand(0).to_str();
1464 str +=
"atan(" + this->get_operand(0).to_str();
1468 str +=
"acos(" + this->get_operand(0).to_str();
1472 str +=
"asin(" + this->get_operand(0).to_str();
1476 str +=
"sinh(" + this->get_operand(0).to_str();
1480 str +=
"cosh(" + this->get_operand(0).to_str();
1484 str +=
"tanh(" + this->get_operand(0).to_str();
1488 str +=
"asinh(" + this->get_operand(0).to_str();
1492 str +=
"acosh(" + this->get_operand(0).to_str();
1496 str +=
"atanh(" + this->get_operand(0).to_str();
1500 str +=
"abs(" + this->get_operand(0).to_str();
1504 str +=
"sqrt(" + this->get_operand(0).to_str();
1508 str +=
"exp(" + this->get_operand(0).to_str();
1512 str +=
"log(" + this->get_operand(0).to_str();
1516 str +=
"ceil(" + this->get_operand(0).to_str();
1520 str +=
"round(" + this->get_operand(0).to_str();
1524 str +=
"trunc(" + this->get_operand(0).to_str();
1528 str +=
"cast(" + this->get_operand(0).to_str();
1535 str += this->get_name() +
"(";
1536 for (
int k = 0; k < this->get_access().size(); k++)
1542 str += this->get_access()[k].to_str();
1547 str += this->get_name() +
"(";
1548 for (
int k = 0; k < this->get_arguments().size(); k++)
1554 str += this->get_arguments()[k].to_str();
1559 str +=
"&" + this->get_operand(0).get_name();
1562 str +=
"allocate(" + this->get_name() +
")";
1565 str +=
"free(" + this->get_name() +
")";
1568 ERROR(
"Dumping an unsupported tiramisu expression.", 1);
1576 str += std::to_string((
int)this->get_uint8_value());
1580 str += std::to_string((
int)this->get_int8_value());
1584 str += std::to_string(this->get_uint16_value());
1588 str += std::to_string(this->get_int16_value());
1592 str += std::to_string(this->get_uint32_value());
1596 str += std::to_string(this->get_int32_value());
1600 str += std::to_string(this->get_uint64_value());
1604 str += std::to_string(this->get_int64_value());
1608 str += std::to_string(this->get_float32_value());
1612 str += std::to_string(this->get_float64_value());
1618 str += this->get_name();
1623 str +=
"sync object";
1627 ERROR(
"Expression type not supported.",
true);
1640 expr substitute(std::vector<std::pair<var, expr>> substitutions)
const;
1650 expr substitute_access(std::string original, std::string substitute)
const;
1655 for (
int i = 0; i < access_vector.size(); i++)
1656 e.access_vector[i] = f(e.access_vector[i]);
1657 for (
int i = 0; i < op.size(); i++)
1658 e.op[i] = f(e.op[i]);
1659 for (
int i = 0; i < argument_vector.size(); i++)
1660 e.argument_vector[i] = f(e.argument_vector[i]);
1674 e.
name =
"_unbounded";
1716 static std::unordered_map<std::string, var> declared_vars;
1723 var(std::string name,
bool save);
1783 lower = lower_bound;
1784 upper = upper_bound;
1800 std::vector<isl_ast_expr *> &index_expr,
1808 template <
typename cT>
1816 return expr{
static_cast<int8_t
>(val)};
1818 return expr{
static_cast<uint8_t
>(val)};
1820 return expr{
static_cast<int16_t
>(val)};
1822 return expr{
static_cast<uint16_t
>(val)};
1824 return expr{
static_cast<int32_t
>(val)};
1826 return expr{
static_cast<uint32_t
>(val)};
1828 return expr{
static_cast<int64_t
>(val)};
1830 return expr{
static_cast<uint64_t
>(val)};
1832 return expr{
static_cast<float>(val)};
1834 return expr{
static_cast<double>(val)};
1836 throw std::invalid_argument{
"Type not supported"};
1846 template <
typename T>
1852 template <
typename T>
1858 template <
typename T>
1864 template <
typename T>
1870 template <
typename T>
1873 return e /
expr{val};
1876 template <
typename T>
1879 return expr{val} / e;
1882 template <
typename T>
1888 template <
typename T>
1894 template <
typename T>
1897 return e %
expr{val};
1900 template <
typename T>
1903 return expr{val} % e;
1906 template <
typename T>
1909 return e >>
expr{val};
1912 template <
typename T>
1915 return expr{val} >> e;
1918 template <
typename T>
1921 return e <<
expr{val};
1924 template <
typename T>
1927 return expr{val} << e;
uint8_t get_uint8_value() const
Return the actual value of the expression.
int get_n_arg() const
Return the number of arguments of the operator.
bool is_equal(tiramisu::expr e) const
Return true if e is identical to this expression.
bool is_unbounded() const
static function * get_implicit_function()
Return the implicit function created during Tiramisu initialization.
double get_float64_value() const
Return the actual value of the expression.
A class that represents computations.
only_integral< T > operator/(const tiramisu::expr &e, T val)
expr_t
The possible types of an expression.
expr cast(primitive_t tT, const expr &e)
Returns an expression that casts e to tT.
expr value_cast(primitive_t tT, cT val)
Takes in a primitive value val, and returns an expression of tiramisu type tT that represents val...
expr(tiramisu::op_t o, tiramisu::primitive_t dtype, tiramisu::expr expr0)
Create a cast expression to type dtype (a unary operator).
tiramisu::expr operator!=(tiramisu::expr e1) const
Comparison operator.
std::string str_tiramisu_type_op(tiramisu::op_t type)
tiramisu::primitive_t get_data_type() const
Get the data type of the expression.
int8_t get_int8_value() const
Return the actual value of the expression.
primitive_t
tiramisu data types.
float get_float32_value() const
Return the actual value of the expression.
static expr unbounded()
Create a variable that can be used to indicate that a dimension is unbounded.
tiramisu::primitive_t dtype
Data type.
static primitive_t get_loop_iterator_data_type()
void set_arguments(std::vector< tiramisu::expr > vector)
Set the arguments of an external function call.
Halide::Expr halide_expr_from_tiramisu_expr(const tiramisu::computation *comp, std::vector< isl_ast_expr * > &index_expr, const tiramisu::expr &tiramisu_expr)
Convert a Tiramisu expression into a Halide expression.
const std::vector< tiramisu::expr > & get_arguments() const
Return the arguments of an external function call.
expr memcpy(const buffer &from, const buffer &to)
var(tiramisu::primitive_t type, std::string name)
Construct an expression that represents a variable.
expr()
Create an undefined expression.
std::string generate_new_variable_name()
tiramisu::expr operator>(tiramisu::expr e1) const
Greater than operator.
void dump(bool exhaustive) const
Dump the object on standard output (dump most of the fields of the expression class).
void set_access(std::vector< tiramisu::expr > vector)
Set the access of a computation or an array.
only_integral< T > operator%(const tiramisu::expr &e, T val)
static void set_default_tiramisu_options()
A class that represents buffers.
bool is_defined() const
Return true if the expression is defined.
expr(double val)
Construct a 64-bit float expression.
expr(int8_t val)
Construct a signed 8-bit integer expression.
A class that represents a synchronization object.
expr(uint8_t val)
Construct an unsigned 8-bit integer expression.
only_integral< T > operator*(const tiramisu::expr &e, T val)
only_integral< T > operator>>(const tiramisu::expr &e, T val)
expr(tiramisu::op_t o, std::string name, std::vector< tiramisu::expr > vec, tiramisu::primitive_t type)
Construct an access or a call.
const std::vector< tiramisu::expr > & get_access() const
Return a vector of the access of the computation or array.
int64_t get_int_val() const
tiramisu::expr simplify() const
Simplify the expression.
expr(tiramisu::op_t o, tiramisu::expr expr0, tiramisu::expr expr1)
Construct an expression for a binary operator.
tiramisu::expr operator>=(tiramisu::expr e1) const
Greater than or equal operator.
A class to represent tiramisu expressions.
tiramisu::expr_t get_expr_type() const
Return the type of the expression (tiramisu::expr_type).
tiramisu::expr_t etype
The type of the expression.
expr(float val)
Construct a 32-bit float expression.
static void set_auto_data_mapping(bool v)
If this option is set to true, Tiramisu automatically modifies the computation data mapping whenever ...
only_integral< T > operator-(const tiramisu::expr &e, T val)
static bool is_auto_data_mapping_set()
Return whether auto data mapping is set.
expr(uint16_t val)
Construct an unsigned 16-bit integer expression.
int get_n_dim_access() const
Get the number of dimensions in the access vector.
uint64_t get_uint64_value() const
Return the actual value of the expression.
double get_double_val() const
tiramisu::expr operator<(tiramisu::expr e1) const
Less than operator.
static void set_implicit_function(function *fct)
Set the implicit function created during Tiramisu initialization.
expr apply_to_operands(std::function< expr(const expr &)> f) const
expr(tiramisu::op_t o, tiramisu::expr expr0)
Create an expression for a unary operator.
A class that holds all the global variables necessary for Tiramisu.
int16_t get_int16_value() const
Return the actual value of the expression.
expr(uint64_t val)
Construct an unsigned 64-bit integer expression.
std::string str_from_tiramisu_type_expr(tiramisu::expr_t type)
tiramisu::expr operator||(tiramisu::expr e1) const
Logical or of two expressions.
const tiramisu::expr & get_operand(int i) const
Return the value of the i 'th operand of the expression.
uint32_t get_uint32_value() const
Return the actual value of the expression.
int32_t get_int32_value() const
Return the actual value of the expression.
uint16_t get_uint16_value() const
Return the actual value of the expression.
expr(tiramisu::op_t o, std::string name)
Create an expression for a unary operator that applies to a variable.
op_t
Types of tiramisu operators.
std::string str_from_tiramisu_type_primitive(tiramisu::primitive_t type)
tiramisu::expr operator&&(tiramisu::expr e1) const
Logical and of two expressions.
tiramisu::op_t get_op_type() const
Get the type of the operator (tiramisu::op_t).
expr allocate(const buffer &b)
tiramisu::expr operator<=(tiramisu::expr e1) const
Less than or equal operator.
var(std::string name)
Construct an expression that represents an untyped variable.
static void set_loop_iterator_type(primitive_t t)
expr(int64_t val)
Construct a signed 64-bit integer expression.
tiramisu::expr replace_op_in_expr(const std::string &to_replace, const std::string &replace_with)
expr(tiramisu::op_t o, tiramisu::expr expr0, tiramisu::expr expr1, tiramisu::expr expr2)
Construct an expression for a ternary operator.
std::string name
Identifier name.
const std::string & get_name() const
Get the name of the ID or the variable represented by this expressions.
only_integral< T > operator+(const tiramisu::expr &e, T val)
int64_t get_int64_value() const
Return the actual value of the expression.
only_integral< T > operator<<(const tiramisu::expr &e, T val)
void set_name(std::string &name)
tiramisu::expr operator==(tiramisu::expr e1) const
Comparison operator.
var(std::string name, expr lower_bound, expr upper_bound)
Construct a loop iterator called name, with the given lower and upper bounds.
typename std::enable_if< std::is_integral< T >::value, expr >::type only_integral
tiramisu::expr operator-() const
Expression multiplied by (-1).
tiramisu::expr operator!() const
Logical NOT of an expression.
A class that represents constant variable references.
expr(int16_t val)
Construct a signed 16-bit integer expression.
expr(uint32_t val)
Construct an unsigned 32-bit integer expression.
std::string to_str() const
bool is_constant() const
Return true if this expression is a literal constant (i.e., 0, 1, 2, ...).
expr(int32_t val)
Construct a signed 32-bit integer expression.
void set_access_dimension(int i, tiramisu::expr acc)
Set an element of the vector of accesses of a computation.
A class for code generation.