5 #ifndef TIRAMISU_CUDA_AST_H 6 #define TIRAMISU_CUDA_AST_H 9 #define NVCC_PATH "nvcc" 12 #define UNARY(op, x) {op, op_data_t{true, 1, (x)}} 13 #define UNARY_TYPED(op, x, T) {op, op_data_t{true, 2, (x), (T)}} 14 #define BINARY(op, x) {op, op_data_t{true, 2, (x)}} 15 #define BINARY_TYPED(op, x, T) {op, op_data_t{true, 2, (x), (T)}} 16 #define TERNARY(op, x, y) {op, op_data_t{true, 3, (x), (y)}} 17 #define TERNARY_TYPED(op, x, y, T) {op, op_data_t{true, 3, (x), (y), (T)}} 18 #define FN_CALL(op, x, n) {op, op_data_t{false, (n), (x)}} 19 #define FN_CALL_TYPED(op, x, n, T) {op, op_data_t{false, (n), (x), (T)}} 21 #define UNARY2(op, x) case op: return tiramisu::cuda_ast::op_data_t{true, 1, (x)}; 22 #define UNARY_TYPED2(op, x, T) case op: return tiramisu::cuda_ast::op_data_t{true, 2, (x), (T)}; 23 #define BINARY2(op, x) case op: return tiramisu::cuda_ast::op_data_t{true, 2, (x)}; 24 #define BINARY_TYPED2(op, x, T) case op: return tiramisu::cuda_ast::op_data_t{true, 2, (x), (T)}; 25 #define TERNARY2(op, x, y) case op: return tiramisu::cuda_ast::op_data_t{true, 3, (x), (y)}; 26 #define TERNARY_TYPED2(op, x, y, T) case op: return tiramisu::cuda_ast::op_data_t{true, 3, (x), (y), (T)}; 27 #define FN_CALL2(op, x, n) case op: return tiramisu::cuda_ast::op_data_t{false, (n), (x)}; 28 #define FN_CALL_TYPED2(op, x, n, T) case op: return tiramisu::cuda_ast::op_data_t{false, (n), (x), (T)}; 40 void operator()(isl_ast_expr * p)
const {isl_ast_expr_free(p);}
48 typedef std::unique_ptr<isl_id, isl_id_deleter>
isl_id_ptr;
52 void operator()(isl_ast_node_list * p)
const {isl_ast_node_list_free(p);}
60 typedef std::unique_ptr<isl_val, isl_val_deleter>
isl_val_ptr;
68 op_data_t(
bool infix,
int arity, std::string && symbol) : infix(infix), arity(arity), symbol(symbol) {}
69 op_data_t(
bool infix,
int arity, std::string && symbol, std::string && next_symbol) : infix(infix), arity(arity), symbol(symbol), next_symbol(next_symbol) {}
70 op_data_t(
bool infix,
int arity, std::string && symbol,
primitive_t type) : infix(infix), arity(arity), symbol(symbol), type_preserving(
72 op_data_t(
bool infix,
int arity, std::string && symbol, std::string && next_symbol,
primitive_t type) : infix(infix), arity(arity), symbol(symbol), next_symbol(next_symbol), type_preserving(
// Equality comparison against another op_data_t. Const-qualified, so it does not
// modify this object. Only the declaration is visible here; which members take
// part in the comparison is decided by the out-of-line definition.
75 bool operator==(
const op_data_t &rhs)
const;
// Inequality comparison against another op_data_t; also const-qualified.
// NOTE(review): presumably the negation of operator== - confirm in the definition.
77 bool operator!=(
const op_data_t &rhs)
const;
// Second operator symbol. Defaults to the empty string and is only overwritten by
// the constructors that take a `next_symbol` argument.
82 std::string next_symbol =
"";
// Defaults to true; the constructors that take a primitive_t initialise it explicitly.
// NOTE(review): presumably tells whether the result type follows the operand type - confirm.
83 bool type_preserving =
true;
// Virtual, not pure: a default implementation is defined elsewhere. Receives the
// output stream by non-const reference.
// NOTE(review): `base` looks like a per-line prefix/indentation string - confirm.
201 virtual void print_body(std::stringstream &ss,
const std::string &base);
// Pure virtual: every concrete node type has to supply its own print().
202 virtual void print(std::stringstream &ss,
const std::string &base) = 0;
// Returns a pair of statement handles.
// NOTE(review): judging by the name, presumably splits a min() bound into two
// parts - confirm against the definition.
203 virtual std::pair<statement_ptr, statement_ptr> extract_min_cap();
// Takes a name -> gpu_iterator map by non-const reference and returns a statement handle.
// NOTE(review): whether this node is changed in place or a new node is returned
// cannot be told from the declaration.
205 virtual statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators);
// Returns a set of strings.
// NOTE(review): presumably the names of the scalars referenced below this node - confirm.
206 virtual std::unordered_set<std::string> extract_scalars();
// Handle to the statement this node wraps (the operand of the cast, going by the name).
220 statement_ptr to_be_cast;
// Overrides the base-class print() for this node type.
223 void print(std::stringstream &ss,
const std::string &base)
override ;
// Overrides the base-class replace_iterators().
// NOTE(review): presumably forwards to `to_be_cast` - confirm in the definition.
224 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Overrides the base-class extract_scalars().
225 std::unordered_set<std::string> extract_scalars()
override;
// Ordered sequence of statement handles owned by this node.
231 std::vector<statement_ptr> elements;
// Overrides the base-class print() for this node type.
234 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class print_body().
235 void print_body(std::stringstream &ss,
const std::string &base)
override ;
// Takes a statement handle by value.
// NOTE(review): presumably appends it to `elements` - confirm in the definition.
242 void add_statement(statement_ptr stmt);
// Returns a const reference to a string owned elsewhere; the reference must not
// outlive its owner. Const-qualified.
252 const std::string &get_name()
const;
// Pure virtual: every concrete identifier type has to supply its own declaration printer.
254 virtual void print_declaration(std::stringstream &ss,
const std::string &base) = 0;
// Virtual, const-qualified predicate with a default implementation defined elsewhere.
// NOTE(review): the default presumably returns false - confirm.
255 virtual bool is_buffer()
const;
// Overrides the base-class print() for this identifier type.
271 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the pure-virtual print_declaration().
272 void print_declaration(std::stringstream &ss,
const std::string &base)
override;
// Non-virtual helper taking an extra separator string.
// NOTE(review): presumably writes the entries of `size` joined by `seperator` - confirm.
273 void print_size(std::stringstream &ss,
const std::string &base,
const std::string &seperator);
// Overrides the base-class is_buffer() predicate.
274 bool is_buffer()
const override;
// One statement handle per entry; NOTE(review): presumably one per dimension - confirm.
278 std::vector<statement_ptr> size;
// Overrides the base-class print() for this identifier type.
290 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the pure-virtual print_declaration().
291 void print_declaration(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class replace_iterators().
// NOTE(review): presumably looks this identifier up in `iterators` - confirm.
292 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Overrides the base-class extract_scalars().
293 std::unordered_set<std::string> extract_scalars()
override;
298 typedef std::shared_ptr<value>
value_ptr;
// One constructor per supported literal type. All are `explicit`, so a value node
// is never created through an implicit conversion from a number.
// Unsigned / signed 8-bit.
305 explicit value(uint8_t val);
306 explicit value(int8_t val);
// Unsigned / signed 16-bit.
307 explicit value(uint16_t val);
308 explicit value(int16_t val);
// Unsigned / signed 32-bit.
309 explicit value(uint32_t val);
310 explicit value(int32_t val);
// Unsigned / signed 64-bit.
311 explicit value(uint64_t val);
312 explicit value(int64_t val);
// Single-precision floating point.
313 explicit value(
float val);
// Double-precision floating point.
314 explicit value(
double val);
// Overrides the base-class print() for this node type.
319 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class replace_iterators().
320 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Virtual, not pure: a default implementation is defined elsewhere and may be
// replaced by derived types.
348 virtual void print_declaration(std::stringstream &ss,
const std::string &base);
// Overrides the base-class print() for this node type.
358 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides print_declaration() for this node type.
359 void print_declaration(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class print() for this node type.
// NOTE(review): the enclosing class header is not visible here - confirm which node this is.
369 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class print() for this node type.
379 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class extract_min_cap().
380 std::pair<statement_ptr, statement_ptr> extract_min_cap()
override;
// Overrides the base-class replace_iterators().
381 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Overrides the base-class extract_scalars().
382 std::unordered_set<std::string> extract_scalars()
override;
// Ordered argument list, one statement handle per argument.
386 std::vector<statement_ptr> arguments;
// Builds a loop node from its four parts; all are statement handles taken by value.
// NOTE(review): `initialization` presumably ends up in the `initial_value` member - confirm.
392 for_loop(statement_ptr initialization, statement_ptr condition, statement_ptr incrementer, statement_ptr body);
// Overrides the base-class print() for this node type.
395 void print(std::stringstream &ss,
const std::string &base)
override;
// Loop header parts. The member holding the loop body is declared on a line not shown here.
398 statement_ptr initial_value;
399 statement_ptr condition;
400 statement_ptr incrementer;
// Builds a conditional with both a then-branch and an else-branch.
407 if_condition(statement_ptr condition, statement_ptr then_body, statement_ptr else_body);
// Builds a conditional with a then-branch only.
// NOTE(review): how `else_body` is represented in that case (null handle vs. flag)
// is not visible here - confirm in the definition.
408 if_condition(statement_ptr condition, statement_ptr then_body);
// Overrides the base-class print() for this node type.
411 void print(std::stringstream &ss,
const std::string &base)
override;
// The tested expression and the two branches, stored as statement handles.
414 statement_ptr condition;
415 statement_ptr then_body;
417 statement_ptr else_body;
// Builds an access node from a buffer handle and a list of index statements;
// the index list is taken by const reference.
423 buffer_access(buffer_ptr accessed,
const std::vector<statement_ptr> &access);
// Overrides the base-class print() for this node type.
426 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class replace_iterators().
427 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Overrides the base-class extract_scalars().
428 std::unordered_set<std::string> extract_scalars()
override;
// Index expressions of the access; NOTE(review): presumably one per dimension - confirm.
432 std::vector<cuda_ast::statement_ptr> access;
// Builds an operator node from a primitive type and its operand list
// (taken by const reference).
439 op(
primitive_t type,
const std::vector<statement_ptr> & operands);
// Overrides the base-class extract_scalars().
441 std::unordered_set<std::string> extract_scalars()
override;
// NOTE(review): the members below appear to belong to a different class than the
// constructor above (the listing skips several lines); its header is not visible.
// Overrides the base-class print() for this node type.
450 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class replace_iterators().
451 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Textual symbol of the operator.
454 std::string m_op_symbol;
// Builds a two-operand operator node. The symbol is taken as an rvalue string,
// so callers pass a temporary or std::move an existing string.
460 binary(
primitive_t type, statement_ptr operand_1, statement_ptr operand_2, std::string &&op_symbol);
// Overrides the base-class print() for this node type.
463 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class replace_iterators().
464 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// Textual symbol of the operator.
466 std::string m_op_symbol;
// Builds a three-operand operator node with two symbols, both taken as rvalue strings.
// NOTE(review): presumably printed as `a <sym1> b <sym2> c` - confirm in print().
472 ternary(
primitive_t type, statement_ptr operand_1, statement_ptr operand_2, statement_ptr operand_3, std::string &&op_symbol_1, std::string &&op_symbol_2);
// Overrides the base-class print() for this node type.
475 void print(std::stringstream &ss,
const std::string &base)
override;
// Overrides the base-class replace_iterators().
476 statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators)
override;
// The two textual symbols of the operator.
478 std::string m_op_symbol_1;
479 std::string m_op_symbol_2;
// Overrides the base-class print() for this node type.
498 void print(std::stringstream &ss,
const std::string &base)
override;
// Handle to an identifier.
// NOTE(review): the visible code does not show whether `id` and `asgmnt` are
// alternatives or can both be set - confirm against the constructors.
503 abstract_identifier_ptr id;
// Handle to an assignment.
504 assignment_ptr asgmnt;
// Overrides the base-class print(); belongs to a class whose header is not visible here.
511 void print(std::stringstream &ss,
const std::string &base)
override;
514 typedef std::unordered_map<std::string, std::pair<tiramisu::primitive_t, cuda_ast::memory_location> >
scalar_data_t;
// Returns a string by value; non-const member.
// NOTE(review): how the name is derived is not visible here - see the definition.
529 std::string simplified_name();
// Overrides the base-class print() for this node type.
540 void print(std::stringstream &ss,
const std::string &base)
override;
// Handle to the statement that is returned.
550 statement_ptr return_value;
// Overrides the base-class print() for this node type.
553 void print(std::stringstream &ss,
const std::string &base)
override;
// Builds a function node from a primitive type, a name (taken by value), an
// argument list (taken by const reference) and a body statement.
// NOTE(review): `type` is presumably the function's return type - confirm.
558 host_function(
primitive_t type, std::string name,
const std::vector<abstract_identifier_ptr> &arguments, statement_ptr body);
// Overrides the base-class print() for this node type.
559 void print(std::stringstream &ss,
const std::string &base)
override;
// Setter for a boolean flag; calling it without an argument passes true.
// NOTE(review): presumably makes the printed return type a pointer - confirm.
560 void set_pointer_return(
bool val =
true);
// Function parameters, in order.
566 std::vector<abstract_identifier_ptr> arguments;
// Three statement handles, one per axis.
576 statement_ptr x, y, z;
// Launch dimensions, each stored as a dim3d_t.
582 dim3d_t block_dimensions;
583 dim3d_t thread_dimensions;
// Scalars and buffers referenced by the kernel, keyed by string. std::map keeps
// the entries ordered by key, so iteration order is deterministic.
584 std::map<std::string, scalar_ptr> used_constants;
585 std::map<std::string, buffer_ptr> used_buffers;
// Class-wide counter shared by all instances; it needs an out-of-class definition.
// NOTE(review): presumably used to give each kernel a unique id - confirm.
587 static int kernel_count;
// Takes the kernel body as a statement handle by value.
592 void set_body(statement_ptr body);
// Both getters return a string by value and are const-qualified.
593 std::string get_name()
const;
594 std::string get_wrapper_name()
const;
// Compile-time constant equal to p_int32; its type is deduced from that enumerator.
595 static constexpr
auto wrapper_return_type =
p_int32;
// NOTE(review): these presumably insert into `used_constants` / `used_buffers` - confirm.
596 void add_used_scalar(scalar_ptr
scalar);
597 void add_used_buffer(buffer_ptr
buffer);
// Returns the argument list by value.
// NOTE(review): presumably built from the used buffers and constants - confirm.
598 std::vector<abstract_identifier_ptr> get_arguments();
// The print() overrides below belong to several node classes whose headers are
// not visible in this excerpt; each one overrides the base-class print().
607 void print(std::stringstream &ss,
const std::string &base)
override;
619 void print(std::stringstream &ss,
const std::string &base)
override;
// Constructor of a node class named `memcpy`, taking source and destination
// buffer handles by value. Inside the enclosing namespace this name hides the
// C library function ::memcpy.
628 memcpy(buffer_ptr from, buffer_ptr to);
// Overrides the base-class print() for the memcpy node.
629 void print(std::stringstream &ss,
const std::string &base)
override;
639 void print(std::stringstream &ss,
const std::string &base)
override;
649 void print(std::stringstream &ss,
const std::string &base)
override;
// Scalar table: name -> (primitive type, memory location), see scalar_data_t.
661 scalar_data_t m_scalar_data;
// Buffer handles keyed by string.
662 std::unordered_map<std::string, cuda_ast::buffer_ptr> m_buffers;
// Kernel currently selected. NOTE(review): presumably only meaningful while
// `in_kernel` is true - confirm.
666 kernel_ptr current_kernel;
// Maps an ISL AST node pointer to a kernel. The key is a raw, non-owning pointer,
// so the nodes have to outlive this map.
667 std::unordered_map<isl_ast_node*, kernel_ptr> iterator_to_kernel_map;
// All kernels collected so far.
668 std::vector<kernel_ptr> kernels;
// Starts out false.
670 bool in_kernel =
false;
// NOTE(review): the next three vectors are presumably kept index-aligned (one
// entry per enclosing loop) - confirm where they are pushed and popped.
671 std::vector<std::string> iterator_stack;
672 std::vector<cuda_ast::statement_ptr> iterator_upper_bound;
673 std::vector<cuda_ast::statement_ptr> iterator_lower_bound;
674 std::vector<cuda_ast::statement_ptr> kernel_simplified_vars;
// GPU iterators keyed by string; this has the map type that replace_iterators() takes.
676 std::unordered_map<std::string, cuda_ast::gpu_iterator> gpu_iterators;
677 std::vector<cuda_ast::statement_ptr> gpu_conditions;
678 std::unordered_set<std::string> gpu_local;
// Takes the name by value and returns a statement handle.
// NOTE(review): presumably looks the name up in `m_scalar_data` - confirm.
682 statement_ptr get_scalar_from_name(std::string name);
// Raw ISL expression pointers stored per computation. Ownership is not expressed
// by the type. NOTE(review): confirm who frees these expressions.
683 std::unordered_map<computation *, std::vector<isl_ast_expr*>> index_exprs;
// Entry point for ISL nodes. NOTE(review): presumably dispatches on the node type
// to the cuda_stmt_handle_isl_* members below - confirm.
687 statement_ptr cuda_stmt_from_isl_node(isl_ast_node *node);
688 statement_ptr cuda_stmt_handle_isl_for(isl_ast_node *node);
// Takes the expression as a reference to an owning handle, so the callee can
// inspect it without taking ownership.
689 statement_ptr cuda_stmt_val_from_for_condition(isl_ast_expr_ptr &
expr, isl_ast_node *node);
690 statement_ptr cuda_stmt_handle_isl_block(isl_ast_node *node);
691 statement_ptr cuda_stmt_handle_isl_if(isl_ast_node *node);
692 statement_ptr cuda_stmt_handle_isl_user(isl_ast_node *node);
694 statement_ptr cuda_stmt_handle_isl_op_expr(isl_ast_expr_ptr &
expr, isl_ast_node *node);
// Calls `fn` with an int and an isl_ast_expr pointer; `start` defaults to 0.
// NOTE(review): presumably iterates the operands of `node` beginning at index
// `start` and passes (index, operand) - confirm.
695 void cuda_stmt_foreach_isl_expr_list(isl_ast_expr *node,
const std::function<
void(
int, isl_ast_expr *)> &fn,
int start = 0);
// Both compile steps return bool and are const-qualified.
// NOTE(review): true presumably means the step succeeded - confirm.
720 bool compile_cpu_obj(
const std::string &filename,
const std::string &obj_name)
const;
721 bool compile_gpu_obj(
const std::string &obj_name)
const;
// Static helper taking a command string and returning an exec_result.
// NOTE(review): presumably runs `cmd` in a shell. The string is passed through
// unchanged, so callers must not build it from untrusted input - confirm.
722 static exec_result exec(
const std::string &cmd);
// Return a string by value, derived from `obj_name`; const-qualified.
// NOTE(review): presumably the object file names for the CPU and GPU parts - confirm.
725 std::string get_cpu_obj(
const std::string &obj_name)
const;
726 std::string get_gpu_obj(
const std::string &obj_name)
const;
// Explicit constructor from a code string; no implicit conversion from std::string.
727 explicit compiler(
const std::string &code);
// Returns bool; const-qualified.
// NOTE(review): presumably runs the GPU and CPU compile steps for `obj_name` - confirm.
728 bool compile(
const std::string &obj_name)
const;
735 #endif //TIRAMISU_CUDA_AST_H std::unique_ptr< isl_val, isl_val_deleter > isl_val_ptr
std::shared_ptr< kernel > kernel_ptr
std::unique_ptr< isl_id, isl_id_deleter > isl_id_ptr
expr cast(primitive_t tT, const expr &e)
Returns an expression that casts e to tT.
op_data_t(bool infix, int arity, std::string &&symbol)
op_data_t(bool infix, int arity, std::string &&symbol, primitive_t type)
primitive_t
tiramisu data types.
void operator()(isl_ast_expr *p) const
expr memcpy(const buffer &from, const buffer &to)
std::unique_ptr< isl_ast_expr, isl_ast_expr_deleter > isl_ast_expr_ptr
std::shared_ptr< scalar > scalar_ptr
op_data_t(bool infix, int arity, std::string &&symbol, std::string &&next_symbol)
std::vector< statement_ptr > m_operands
const op_data_t isl_operation_description(isl_ast_op_type op)
void operator()(isl_val *p) const
std::shared_ptr< buffer > buffer_ptr
op_data_t(bool infix, int arity, std::string &&symbol, std::string &&next_symbol, primitive_t type)
const std::string tiramisu_type_to_cuda_type(tiramisu::primitive_t t)
std::shared_ptr< assignment > assignment_ptr
void operator()(isl_id *p) const
A class to represent tiramisu expressions.
A class to represent functions in Tiramisu.
A class that holds all the global variables necessary for Tiramisu.
std::shared_ptr< abstract_identifier > abstract_identifier_ptr
std::unordered_map< std::string, std::pair< tiramisu::primitive_t, cuda_ast::memory_location > > scalar_data_t
A class that represents loop invariants.
op_t
Types of tiramisu operators.
const op_data_t tiramisu_operation_description(tiramisu::op_t op)
expr allocate(const buffer &b)
void operator()(isl_ast_node_list *p) const
std::shared_ptr< statement > statement_ptr
std::unique_ptr< isl_ast_node_list, isl_ast_node_list_deleter > isl_ast_node_list_ptr
std::shared_ptr< value > value_ptr