Tiramisu Compiler
cuda_ast.h
Go to the documentation of this file.
1 //
2 // Created by malek on 12/18/17.
3 //
4 
5 #ifndef TIRAMISU_CUDA_AST_H
6 #define TIRAMISU_CUDA_AST_H
7 
// Path of the CUDA compiler driver; can be overridden at build time with
// -DNVCC_PATH=... (presumably consumed by compiler::compile_gpu_obj — confirm).
#ifndef NVCC_PATH
#define NVCC_PATH "nvcc"
#endif

// Table-entry helpers: each expands to the brace initializer
//   {op, op_data_t{infix, arity, symbol[, next_symbol][, type]}}
// Arguments:
//   op - operator enumerator
//   x  - operator symbol or function name
//   y  - second symbol of a ternary operator (e.g. ":" of "?:")
//   n  - argument count of a function-call style operator
//   T  - explicit result type; the *_TYPED forms select the op_data_t
//        constructors that set type_preserving = false
// The arity must equal the operand count: every unary form uses 1 (typed or
// not), binary forms 2, ternary forms 3.
#define UNARY(op, x) {op, op_data_t{true, 1, (x)}}
#define UNARY_TYPED(op, x, T) {op, op_data_t{true, 1, (x), (T)}}
#define BINARY(op, x) {op, op_data_t{true, 2, (x)}}
#define BINARY_TYPED(op, x, T) {op, op_data_t{true, 2, (x), (T)}}
#define TERNARY(op, x, y) {op, op_data_t{true, 3, (x), (y)}}
#define TERNARY_TYPED(op, x, y, T) {op, op_data_t{true, 3, (x), (y), (T)}}
#define FN_CALL(op, x, n) {op, op_data_t{false, (n), (x)}}
#define FN_CALL_TYPED(op, x, n, T) {op, op_data_t{false, (n), (x), (T)}}

// Switch-case helpers: each expands to
//   case op: return tiramisu::cuda_ast::op_data_t{...};
// with the same arguments and arities as the table-entry helpers above.
#define UNARY2(op, x) case op: return tiramisu::cuda_ast::op_data_t{true, 1, (x)};
#define UNARY_TYPED2(op, x, T) case op: return tiramisu::cuda_ast::op_data_t{true, 1, (x), (T)};
#define BINARY2(op, x) case op: return tiramisu::cuda_ast::op_data_t{true, 2, (x)};
#define BINARY_TYPED2(op, x, T) case op: return tiramisu::cuda_ast::op_data_t{true, 2, (x), (T)};
#define TERNARY2(op, x, y) case op: return tiramisu::cuda_ast::op_data_t{true, 3, (x), (y)};
#define TERNARY_TYPED2(op, x, y, T) case op: return tiramisu::cuda_ast::op_data_t{true, 3, (x), (y), (T)};
#define FN_CALL2(op, x, n) case op: return tiramisu::cuda_ast::op_data_t{false, (n), (x)};
#define FN_CALL_TYPED2(op, x, n, T) case op: return tiramisu::cuda_ast::op_data_t{false, (n), (x), (T)};
29 
#include <isl/id.h>
#include <tiramisu/type.h>

#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "utils.h"
35 
36 namespace tiramisu
37 {
39  {
40  void operator()(isl_ast_expr * p) const {isl_ast_expr_free(p);}
41  };
42  typedef std::unique_ptr<isl_ast_expr, isl_ast_expr_deleter> isl_ast_expr_ptr;
43 
45  {
46  void operator()(isl_id * p) const {isl_id_free(p);}
47  };
48  typedef std::unique_ptr<isl_id, isl_id_deleter> isl_id_ptr;
49 
51  {
52  void operator()(isl_ast_node_list * p) const {isl_ast_node_list_free(p);}
53  };
54  typedef std::unique_ptr<isl_ast_node_list, isl_ast_node_list_deleter> isl_ast_node_list_ptr;
55 
57  {
58  void operator()(isl_val * p) const {isl_val_free(p);}
59  };
60  typedef std::unique_ptr<isl_val, isl_val_deleter> isl_val_ptr;
61 
62  class function;
63 namespace cuda_ast
64 {
65  struct op_data_t
66  {
67  op_data_t() {}
68  op_data_t(bool infix, int arity, std::string && symbol) : infix(infix), arity(arity), symbol(symbol) {}
69  op_data_t(bool infix, int arity, std::string && symbol, std::string && next_symbol) : infix(infix), arity(arity), symbol(symbol), next_symbol(next_symbol) {}
70  op_data_t(bool infix, int arity, std::string && symbol, primitive_t type) : infix(infix), arity(arity), symbol(symbol), type_preserving(
71  false), type(type) {}
72  op_data_t(bool infix, int arity, std::string && symbol, std::string && next_symbol, primitive_t type) : infix(infix), arity(arity), symbol(symbol), next_symbol(next_symbol), type_preserving(
73  false), type(type) {}
74 
75  bool operator==(const op_data_t &rhs) const;
76 
77  bool operator!=(const op_data_t &rhs) const;
78 
79  bool infix;
80  int arity;
81  std::string symbol;
82  std::string next_symbol = "";
83  bool type_preserving = true;
85  };
86 
88 
89 // const std::unordered_map <tiramisu::op_t , op_data_t> tiramisu_operation_description = {
90 // UNARY(o_minus, "-"),
91 // FN_CALL(o_floor, "floor", 1),
92 // FN_CALL(o_sin, "sin", 1),
93 // FN_CALL(o_cos, "cos", 1),
94 // FN_CALL(o_tan, "tan", 1),
95 // FN_CALL(o_asin, "asin", 1),
96 // FN_CALL(o_acos, "acos", 1),
97 // FN_CALL(o_atan, "atan", 1),
98 // FN_CALL(o_sinh, "sinh", 1),
99 // FN_CALL(o_cosh, "cosh", 1),
100 // FN_CALL(o_tanh, "tanh", 1),
101 // FN_CALL(o_asinh, "asinh", 1),
102 // FN_CALL(o_acosh, "acosh", 1),
103 // FN_CALL(o_atanh, "atanh", 1),
104 // FN_CALL(o_abs, "abs", 1),
105 // FN_CALL(o_sqrt, "sqrt", 1),
106 // FN_CALL(o_expo, "exp", 1),
107 // FN_CALL(o_log, "log", 1),
108 // FN_CALL(o_ceil, "ceil", 1),
109 // FN_CALL(o_round, "round", 1),
110 // FN_CALL(o_trunc, "trunc", 1),
111 // BINARY(o_add, "+"),
112 // BINARY(o_sub, "-"),
113 // BINARY(o_mul, "*"),
114 // BINARY(o_div, "/"),
115 // BINARY(o_mod, "%"),
116 // BINARY(o_logical_and, "&&"),
117 // BINARY(o_logical_or, "||"),
118 // UNARY(o_logical_not, "!"),
119 // BINARY(o_eq, "=="),
120 // BINARY(o_ne, "!="),
121 // BINARY(o_le, "<="),
122 // BINARY(o_lt, "<"),
123 // BINARY(o_ge, ">="),
124 // BINARY(o_gt, ">"),
125 // FN_CALL(o_max, "max", 2),
126 // FN_CALL(o_min, "min", 2),
127 // BINARY(o_right_shift, ">>"),
128 // BINARY(o_left_shift, "<<"),
129 // TERNARY(o_select, "?", ":"),
130 // FN_CALL(o_lerp, "lerp", 3),
131 // };
132 
// Returns the rendering description (symbol, arity, infix/call form, result
// type) for an ISL AST operator; the commented-out table below shows the
// mapping it replaced.
// NOTE(review): the top-level const on this by-value return type prevents
// callers from moving from the result; consider plain op_data_t, changed
// together with the out-of-line definition.
133 const op_data_t isl_operation_description(isl_ast_op_type op);
134 
135 // const std::unordered_map <isl_ast_op_type , op_data_t> isl_operation_description = {
136 // BINARY_TYPED(isl_ast_op_and, "&&", p_boolean),
137 // BINARY_TYPED(isl_ast_op_and_then, "&&", p_boolean),
138 // BINARY_TYPED(isl_ast_op_or, "||", p_boolean),
139 // BINARY_TYPED(isl_ast_op_or_else, "||", p_boolean),
140 // FN_CALL(isl_ast_op_max, "max", 2),
141 // FN_CALL(isl_ast_op_min, "min", 2),
142 // UNARY(isl_ast_op_minus, "-"),
143 // BINARY(isl_ast_op_add, "+"),
144 // BINARY(isl_ast_op_sub, "-"),
145 // BINARY(isl_ast_op_mul, "*"),
146 // BINARY(isl_ast_op_div, "/"),
147 // BINARY(isl_ast_op_fdiv_q, "/"),
148 // BINARY(isl_ast_op_pdiv_q, "/"),
149 // BINARY(isl_ast_op_pdiv_r, "%"),
150 // BINARY(isl_ast_op_zdiv_r, "%"),
151 // TERNARY(isl_ast_op_cond, "?", ":"),
152 // FN_CALL(isl_ast_op_select, "lerp", 3),
153 // BINARY_TYPED(isl_ast_op_eq, "==", p_boolean),
154 // BINARY_TYPED(isl_ast_op_le, "<=", p_boolean),
155 // BINARY_TYPED(isl_ast_op_lt, "<", p_boolean),
156 // BINARY_TYPED(isl_ast_op_ge, ">=", p_boolean),
157 // BINARY_TYPED(isl_ast_op_gt, ">", p_boolean),
158 // };
159 
161 
162 
163 // const std::unordered_map <tiramisu::primitive_t, std::string> tiramisu_type_to_cuda_type = {
164 // {p_none, "void"},
165 // {p_boolean, "bool"},
166 // {p_int8, "int8_t"},
167 // {p_uint8, "uint8_t"},
168 // {p_int16, "int16_t"},
169 // {p_uint16, "uint16_t"},
170 // {p_int32, "int32_t"},
171 // {p_uint32, "uint32_t"},
172 // {p_int64, "int64_t"},
173 // {p_uint64, "uint64_t"},
174 // {p_float32, "float"},
175 // {p_float64, "double"},
176 // };
// Memory where an identifier lives: host memory or one of the device memory
// spaces (`reg` presumably meaning a register-held local). Underlying values
// follow declaration order: host = 0 ... reg = 5.
enum class memory_location { host, global, shared, local, constant, reg };
186 
188 
// NOTE(review): this "};" closes a definition whose opening line (source
// line 187) is missing from this listing — presumably the declaration of
// abstract_node, which statement derives from but which is not otherwise
// declared here. Confirm against the real header.
189 };
190 
// Forward declarations: statement_ptr is the shared-ownership handle used for
// every AST node below; gpu_iterator is needed by replace_iterators.
191  class statement;
192  typedef std::shared_ptr<statement> statement_ptr;
193 
194 
195 
196  struct gpu_iterator;
// Abstract base of every node of the generated CUDA AST. Each statement is
// constructed with a primitive type; get_type() presumably returns it.
197 class statement : public abstract_node {
198 public:
199  primitive_t get_type() const;
// Convenience overload returning the printed form as a string — presumably
// built with print(ss, base); confirm in the implementation.
200  std::string print();
201  virtual void print_body(std::stringstream &ss, const std::string &base);
// Each concrete node appends its CUDA source text to ss; `base` is presumably
// the indentation prefix for the current nesting level.
202  virtual void print(std::stringstream &ss, const std::string &base) = 0;
// Overridden by function_call; semantics not visible in this header.
203  virtual std::pair<statement_ptr, statement_ptr> extract_min_cap();
204  // TODO implement in more subclasses
// Overridden by cast, scalar, value, function_call, buffer_access, unary,
// binary and ternary; presumably returns a copy in which references to the
// iterators named in the map are replaced by GPU iterator reads.
205  virtual statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators);
// Presumably collects the names of the scalars the statement references.
206  virtual std::unordered_set<std::string> extract_scalars();
207 
208 protected:
209 
// Only subclasses construct statements.
210  explicit statement(primitive_t type);
211 // template <std::function<statement_ptr(statement_ptr)> F>
212 // virtual statement_ptr apply()
213 // {}
214 
215 private:
// NOTE(review): the private member on source line 216 is missing from this
// listing — presumably the stored type returned by get_type().
217 };
218 
219 class cast : public statement {
220  statement_ptr to_be_cast;
221 public:
222  cast(primitive_t type, statement_ptr stmt);
223  void print(std::stringstream &ss, const std::string &base) override ;
224  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
225  std::unordered_set<std::string> extract_scalars() override;
226 
227 };
228 
229 class block : public statement {
230 private:
231  std::vector<statement_ptr> elements;
232 
233 public:
234  void print(std::stringstream &ss, const std::string &base) override;
235  void print_body(std::stringstream &ss, const std::string &base) override ;
236 
237 public:
238  block();
239 
240  virtual ~block();
241 
242  void add_statement(statement_ptr stmt);
243 
244 };
245 
// Base for named identifiers in the generated code: a primitive type, a
// name and a memory_location. buffer and scalar (whose print_declaration is
// marked override) presumably derive from it.
// NOTE(review): the class-head line (source line 246) is missing from this
// listing; the name is known from the constructor and the typedef below, and
// the base is presumably statement (buffer overrides print) — confirm.
247 {
248 protected:
249  abstract_identifier(primitive_t type, const std::string &name, memory_location location);
250 
251 public:
252  const std::string &get_name() const;
253  memory_location get_location() const;
// Prints the identifier's declaration; pure virtual.
254  virtual void print_declaration(std::stringstream &ss, const std::string &base) = 0;
// Overridden by buffer; presumably false for every other identifier.
255  virtual bool is_buffer() const;
256 
257 
258 private:
259  std::string name;
260  cuda_ast::memory_location location;
261 
262 public:
263 
264 };
// Shared-ownership handle to an identifier.
265  typedef std::shared_ptr<abstract_identifier> abstract_identifier_ptr;
266 
// Array identifier whose extent is described by the statements in `size`
// (presumably one per dimension — confirm the ordering).
// NOTE(review): the class-head line (source line 267) is missing from this
// listing; the name buffer is known from the constructor and typedef below.
268 {
269 public:
270  buffer(primitive_t type, const std::string &name, memory_location location, const std::vector<statement_ptr> &size);
271  void print(std::stringstream &ss, const std::string &base) override;
272  void print_declaration(std::stringstream &ss, const std::string &base) override;
// Presumably prints the size statements joined by `seperator` (sic).
273  void print_size(std::stringstream &ss, const std::string &base, const std::string &seperator);
274  bool is_buffer() const override;
275 
276 
277 private:
278  std::vector<statement_ptr> size;
279 };
// Shared-ownership handle to a buffer.
280  typedef std::shared_ptr<buffer> buffer_ptr;
281 
// Single-valued identifier. The four-argument constructor also takes
// is_const (presumably whether it is declared const).
// NOTE(review): the class-head line (source line 282) is missing from this
// listing; the name scalar is known from the constructors and typedef below.
// NOTE(review): is_const has no default member initializer; confirm the
// three-argument constructor initializes it.
283 {
284  bool is_const;
285 public:
286  scalar(primitive_t type, const std::string &name, memory_location location);
287  scalar(primitive_t type, const std::string &name, memory_location location, bool is_const);
288 
289 public:
290  void print(std::stringstream &ss, const std::string &base) override;
291  void print_declaration(std::stringstream &ss, const std::string &base) override;
292  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
293  std::unordered_set<std::string> extract_scalars() override;
294 };
// Shared-ownership handle to a scalar.
295  typedef std::shared_ptr<scalar> scalar_ptr;
296 
297 class value;
298 typedef std::shared_ptr<value> value_ptr;
299 
300 class value : public statement
301 {
302 public:
303 
304  explicit value(const tiramisu::expr & expr);
305  explicit value(uint8_t val);
306  explicit value(int8_t val);
307  explicit value(uint16_t val);
308  explicit value(int16_t val);
309  explicit value(uint32_t val);
310  explicit value(int32_t val);
311  explicit value(uint64_t val);
312  explicit value(int64_t val);
313  explicit value(float val);
314  explicit value(double val);
315 
316  value_ptr copy();
317 
318 public:
319  void print(std::stringstream &ss, const std::string &base) override;
320  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
321 
322 private:
323 
324  /**
325  * The value.
326  */
327  union
328  {
329  uint8_t u8_val;
330  int8_t i8_val;
331  uint16_t u16_val;
332  int16_t i16_val;
333  uint32_t u32_val;
334  int32_t i32_val;
335  uint64_t u64_val;
336  int64_t i64_val;
337  float f32_val;
338  double f64_val;
339  };
340 };
341 
342 class assignment : public statement
343 {
344 protected:
345  explicit assignment(primitive_t type);
346 
347 public:
348  virtual void print_declaration(std::stringstream &ss, const std::string &base);
349 };
350  typedef std::shared_ptr<assignment> assignment_ptr;
351 
// Assignment of m_rhs to the scalar m_scalar.
// NOTE(review): the class-head line (source line 352) is missing from this
// listing; presumably `class scalar_assignment : public assignment`
// (print_declaration is marked override) — confirm.
353 {
354  scalar_ptr m_scalar;
355  statement_ptr m_rhs;
356 public:
357  scalar_assignment(scalar_ptr scalar, statement_ptr rhs);
358  void print(std::stringstream &ss, const std::string &base) override;
359  void print_declaration(std::stringstream &ss, const std::string &base) override;
360 
361 };
362 
// Store into an element of m_buffer, presumably at the index given by
// m_index_access.
// NOTE(review): the class-head line (source line 363) and the member on
// source line 367 are missing from this listing; the constructor's `rhs` is
// presumably stored in that missing member — confirm.
364 {
365  cuda_ast::buffer_ptr m_buffer;
366  cuda_ast::statement_ptr m_index_access;
368 public:
369  void print(std::stringstream &ss, const std::string &base) override;
370 public:
371  buffer_assignment(cuda_ast::buffer_ptr buffer, statement_ptr index_access, statement_ptr rhs);
372 };
373 
374 class function_call : public statement
375 {
376 public:
377  function_call(primitive_t type, const std::string &name, const std::vector<statement_ptr> &arguments);
378 public:
379  void print(std::stringstream &ss, const std::string &base) override;
380  std::pair<statement_ptr, statement_ptr> extract_min_cap() override;
381  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
382  std::unordered_set<std::string> extract_scalars() override;
383 
384 private:
385  std::string name;
386  std::vector<statement_ptr> arguments;
387 };
388 
389 class for_loop : public statement
390 {
391 public:
392  for_loop(statement_ptr initialization, statement_ptr condition, statement_ptr incrementer, statement_ptr body);
393 
394 public:
395  void print(std::stringstream &ss, const std::string &base) override;
396 
397 private:
398  statement_ptr initial_value;
399  statement_ptr condition;
400  statement_ptr incrementer;
401  statement_ptr body;
402 };
403 
404 class if_condition : public statement
405 {
406 public:
407  if_condition(statement_ptr condition, statement_ptr then_body, statement_ptr else_body);
408  if_condition(statement_ptr condition, statement_ptr then_body);
409 
410 public:
411  void print(std::stringstream &ss, const std::string &base) override;
412 
413 private:
414  statement_ptr condition;
415  statement_ptr then_body;
416  bool has_else;
417  statement_ptr else_body;
418 };
419 
420 class buffer_access : public statement
421 {
422 public:
423  buffer_access(buffer_ptr accessed, const std::vector<statement_ptr> &access);
424 
425 public:
426  void print(std::stringstream &ss, const std::string &base) override;
427  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
428  std::unordered_set<std::string> extract_scalars() override;
429 
430 private:
431  buffer_ptr accessed;
432  std::vector<cuda_ast::statement_ptr> access;
433 };
434 
435 class op : public statement
436 {
437 
438 protected:
439  op(primitive_t type, const std::vector<statement_ptr> & operands);
440  std::vector<statement_ptr> m_operands;
441  std::unordered_set<std::string> extract_scalars() override;
442 
443 };
444 
445 class unary : public op
446 {
447 public:
448  unary(primitive_t type, statement_ptr operand, std::string &&op_symbol);
449 public:
450  void print(std::stringstream &ss, const std::string &base) override;
451  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
452 
453 private:
454  std::string m_op_symbol;
455 };
456 
457 class binary : public op
458 {
459 public:
460  binary(primitive_t type, statement_ptr operand_1, statement_ptr operand_2, std::string &&op_symbol);
461 
462 public:
463  void print(std::stringstream &ss, const std::string &base) override;
464  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
465 private:
466  std::string m_op_symbol;
467 };
468 
469 class ternary : public op
470 {
471 public:
472  ternary(primitive_t type, statement_ptr operand_1, statement_ptr operand_2, statement_ptr operand_3, std::string &&op_symbol_1, std::string &&op_symbol_2);
473 
474 public:
475  void print(std::stringstream &ss, const std::string &base) override;
476  statement_ptr replace_iterators(std::unordered_map<std::string, gpu_iterator> & iterators) override;
477 private:
478  std::string m_op_symbol_1;
479  std::string m_op_symbol_2;
480 
481 };
482 
483 //class assignment : public statement
484 //{
485 //public:
486 // assignment(primitive_t type, abstract_identifier *identifier, statement *value);
487 //
488 //private:
489 // abstract_identifier * identifier;
490 // statement * value;
491 //};
492 
493 class declaration : public statement
494 {
495 public:
496  explicit declaration (abstract_identifier_ptr id);
497  explicit declaration (assignment_ptr asgmnt);
498  void print(std::stringstream &ss, const std::string &base) override;
499 
500 
501 private:
502  bool is_initialized;
503  abstract_identifier_ptr id;
504  assignment_ptr asgmnt;
505 };
506 
507 class sync : public statement
508 {
509 public:
510  sync();
511  void print(std::stringstream &ss, const std::string &base) override;
512 };
513 
514 typedef std::unordered_map<std::string, std::pair<tiramisu::primitive_t, cuda_ast::memory_location> > scalar_data_t;
515 
// Describes how a loop iterator maps onto the CUDA launch: a THREAD or BLOCK
// index (type) along x/y/z (dimension); size is presumably the iterator's
// extent along that dimension.
// NOTE(review): the struct-head line (source line 516) is missing from this
// listing; the name gpu_iterator comes from its forward declaration and uses.
517 {
518  enum class type_t {
519  THREAD,
520  BLOCK
521  } type;
// Enumerator values: x = 0, y = 1, z = 2.
522  enum class dimension_t{
523  x = 0,
524  y,
525  z
526  } dimension;
527  statement_ptr size;
528  // returns a simplified name; __tx__, __ty__, __tz__, __bx__, __by__, __bz__
529  std::string simplified_name();
530 };
531 
// Reads a GPU iterator's value in generated code; when `simplified` is set
// the short name from gpu_iterator::simplified_name() is presumably used.
// NOTE(review): the class-head line (source line 532) is missing from this
// listing; the name comes from the constructors. Confirm the one-argument
// constructor initializes `simplified` (it has no default initializer).
533 {
534 private:
535  gpu_iterator it;
536  bool simplified;
537 public:
538  explicit gpu_iterator_read(gpu_iterator it);
539  explicit gpu_iterator_read(gpu_iterator it, bool simplified);
540  void print(std::stringstream &ss, const std::string &base) override;
541 };
542 
543 
// Forward declarations of the classes that kernel befriends.
544  class kernel_call;
545  class kernel_definition;
546 
// Statement returning return_value (presumably printed as `return <value>;`).
// NOTE(review): the class-head line (source line 547) is missing from this
// listing; the name comes from the constructor.
548 {
549 private:
550  statement_ptr return_value;
551 public:
552  explicit return_statement(statement_ptr return_value);
553  void print(std::stringstream &ss, const std::string &base) override;
554 };
555 class host_function : public statement
556 {
557 public:
558  host_function(primitive_t type, std::string name, const std::vector<abstract_identifier_ptr> &arguments, statement_ptr body);
559  void print(std::stringstream &ss, const std::string &base) override;
560  void set_pointer_return(bool val = true);
561 
562 private:
563  bool pointer_return;
564  std::string name;
565  statement_ptr body;
566  std::vector<abstract_identifier_ptr> arguments;
567 };
568 
569 class kernel
570 {
571  friend class kernel_call;
572  friend class kernel_definition;
573 private:
574  struct dim3d_t
575  {
576  statement_ptr x, y, z;
577  dim3d_t();
578 
579  void set(gpu_iterator::dimension_t dim, statement_ptr size);
580 
581  };
582  dim3d_t block_dimensions;
583  dim3d_t thread_dimensions;
584  std::map<std::string, scalar_ptr> used_constants;
585  std::map<std::string, buffer_ptr> used_buffers;
586  statement_ptr body;
587  static int kernel_count;
588  int kernel_number;
589 public:
590  kernel();
591  void set_dimension(gpu_iterator dimension);
592  void set_body(statement_ptr body);
593  std::string get_name() const;
594  std::string get_wrapper_name() const;
595  static constexpr auto wrapper_return_type = p_int32;
596  void add_used_scalar(scalar_ptr scalar);
597  void add_used_buffer(buffer_ptr buffer);
598  std::vector<abstract_identifier_ptr> get_arguments();
599 
600 };
601 typedef std::shared_ptr<kernel> kernel_ptr;
602 
603 class kernel_call : public statement
604 {
605 public:
606  explicit kernel_call (kernel_ptr kernel);
607  void print(std::stringstream &ss, const std::string &base) override;
608 
609 
610 private:
611  kernel_ptr kernel;
612 
613 };
614 
// Statement that prints the definition of `kernel` (presumably the device
// function and its host wrapper — confirm in the implementation).
// NOTE(review): the class-head line (source line 615) is missing from this
// listing; presumably `class kernel_definition : public statement`.
616 {
617 public:
618  explicit kernel_definition(kernel_ptr kernel);
619  void print(std::stringstream &ss, const std::string &base) override;
620 
621 private:
622  kernel_ptr kernel;
623 };
624 
625 class memcpy : public statement
626 {
627 public:
628  memcpy(buffer_ptr from, buffer_ptr to);
629  void print(std::stringstream &ss, const std::string &base) override;
630 private:
631  buffer_ptr from, to;
632 };
633 
634 
635 class allocate : public statement
636 {
637 public:
638  allocate(buffer_ptr b);
639  void print(std::stringstream &ss, const std::string &base) override;
640 
641 private:
642  buffer_ptr b;
643 };
644 
645 class free : public statement
646 {
647 public:
648  free(buffer_ptr b);
649  void print(std::stringstream &ss, const std::string &base) override;
650 
651 private:
652  buffer_ptr b;
653 };
654 
655 
// Builds cuda_ast statements from an ISL AST. The cuda_stmt_handle_isl_*
// members handle individual ISL node and expression kinds;
// cuda_stmt_from_isl_node is presumably the entry point. Kernels encountered
// are presumably collected in `kernels`.
// NOTE(review): the class-head line (source line 656) is missing from this
// listing; the name generator is known from the constructor.
657 {
658  friend class tiramisu::function;
659 private:
// Reference member: the function passed to the constructor must outlive
// this object.
660  const tiramisu::function &m_fct;
661  scalar_data_t m_scalar_data;
662  std::unordered_map<std::string, cuda_ast::buffer_ptr> m_buffers;
663  cuda_ast::buffer_ptr get_buffer(const std::string & name);
664  cuda_ast::statement_ptr parse_tiramisu(const tiramisu::expr & tiramisu_expr);
665  int loop_level = 0;
666  kernel_ptr current_kernel;
667  std::unordered_map<isl_ast_node*, kernel_ptr> iterator_to_kernel_map;
668  std::vector<kernel_ptr> kernels;
669  // Will be set to true as soon as GPU computation is encountered, and set to false as soon as GPU loop is exited
670  bool in_kernel = false;
671  std::vector<std::string> iterator_stack;
672  std::vector<cuda_ast::statement_ptr> iterator_upper_bound;
673  std::vector<cuda_ast::statement_ptr> iterator_lower_bound;
674  std::vector<cuda_ast::statement_ptr> kernel_simplified_vars;
675  // A mapping from iterator name to GPU info
676  std::unordered_map<std::string, cuda_ast::gpu_iterator> gpu_iterators;
677  std::vector<cuda_ast::statement_ptr> gpu_conditions;
678  std::unordered_set<std::string> gpu_local;
// NOTE(review): the first line of the next declaration (source line 679) is
// missing from this listing; only its last two parameters remain below.
680  cuda_ast::statement_ptr lower_bound,
681  cuda_ast::statement_ptr upper_bound);
682  statement_ptr get_scalar_from_name(std::string name);
683  std::unordered_map<computation *, std::vector<isl_ast_expr*>> index_exprs;
684 public:
685  explicit generator(tiramisu::function &fct);
686 
687  statement_ptr cuda_stmt_from_isl_node(isl_ast_node *node);
688  statement_ptr cuda_stmt_handle_isl_for(isl_ast_node *node);
689  statement_ptr cuda_stmt_val_from_for_condition(isl_ast_expr_ptr &expr, isl_ast_node *node);
690  statement_ptr cuda_stmt_handle_isl_block(isl_ast_node *node);
691  statement_ptr cuda_stmt_handle_isl_if(isl_ast_node *node);
692  statement_ptr cuda_stmt_handle_isl_user(isl_ast_node *node);
693  cuda_ast::statement_ptr cuda_stmt_handle_isl_expr(isl_ast_expr_ptr &expr, isl_ast_node *node);
694  statement_ptr cuda_stmt_handle_isl_op_expr(isl_ast_expr_ptr &expr, isl_ast_node *node);
// Calls fn(index, element) for each expression in the list, starting at `start`
// (presumably the argument list of an ISL operation expression).
695  void cuda_stmt_foreach_isl_expr_list(isl_ast_expr *node, const std::function<void(int, isl_ast_expr *)> &fn, int start = 0);
696 
697 
698  static cuda_ast::value_ptr cuda_stmt_handle_isl_val(isl_val_ptr &node);
699 };
700 
// Outcome of running an external command through compiler::exec: `result`
// presumably holds its exit status, std_out / std_err the captured output.
// fail() and succeed() report the outcome (defined out of line).
// NOTE(review): source line 704 (`bool exec_succeeded;` per this page's
// index) is missing from this listing.
// NOTE(review): an unnamed namespace in a header gives every translation unit
// a distinct exec_result, yet compiler::exec (a member of a class with
// external linkage) returns it — an ODR hazard. Consider a named namespace,
// changed together with the out-of-line definitions; fail()/succeed() could
// also be const.
701 namespace {
702 
703  struct exec_result {
705  int result;
706  std::string std_out;
707  std::string std_err;
708 
709  bool fail();
710 
711  bool succeed();
712  };
713 
714 }
715 
716  class compiler
717  {
718  std::string code;
719 
720  bool compile_cpu_obj(const std::string &filename, const std::string &obj_name) const;
721  bool compile_gpu_obj(const std::string &obj_name) const;
722  static exec_result exec(const std::string &cmd);
723 
724  public:
725  std::string get_cpu_obj(const std::string &obj_name) const;
726  std::string get_gpu_obj(const std::string &obj_name) const;
727  explicit compiler(const std::string &code);
728  bool compile(const std::string &obj_name) const;
729  };
730 
731 }
732 
733 }
734 
735 #endif //TIRAMISU_CUDA_AST_H
std::unique_ptr< isl_val, isl_val_deleter > isl_val_ptr
Definition: cuda_ast.h:60
std::shared_ptr< kernel > kernel_ptr
Definition: cuda_ast.h:601
std::unique_ptr< isl_id, isl_id_deleter > isl_id_ptr
Definition: cuda_ast.h:48
std::string std_out
Definition: cuda_ast.h:706
expr cast(primitive_t tT, const expr &e)
Returns an expression that casts e to tT.
op_data_t(bool infix, int arity, std::string &&symbol)
Definition: cuda_ast.h:68
op_data_t(bool infix, int arity, std::string &&symbol, primitive_t type)
Definition: cuda_ast.h:70
primitive_t
tiramisu data types.
Definition: type.h:27
void operator()(isl_ast_expr *p) const
Definition: cuda_ast.h:40
expr memcpy(const buffer &from, const buffer &to)
std::unique_ptr< isl_ast_expr, isl_ast_expr_deleter > isl_ast_expr_ptr
Definition: cuda_ast.h:42
std::shared_ptr< scalar > scalar_ptr
Definition: cuda_ast.h:295
op_data_t(bool infix, int arity, std::string &&symbol, std::string &&next_symbol)
Definition: cuda_ast.h:69
int result
Definition: cuda_ast.h:705
std::vector< statement_ptr > m_operands
Definition: cuda_ast.h:440
bool exec_succeeded
Definition: cuda_ast.h:704
const op_data_t isl_operation_description(isl_ast_op_type op)
void operator()(isl_val *p) const
Definition: cuda_ast.h:58
std::shared_ptr< buffer > buffer_ptr
Definition: cuda_ast.h:280
op_data_t(bool infix, int arity, std::string &&symbol, std::string &&next_symbol, primitive_t type)
Definition: cuda_ast.h:72
const std::string tiramisu_type_to_cuda_type(tiramisu::primitive_t t)
std::shared_ptr< assignment > assignment_ptr
Definition: cuda_ast.h:350
void operator()(isl_id *p) const
Definition: cuda_ast.h:46
A class to represent tiramisu expressions.
Definition: expr.h:150
A class to represent functions in Tiramisu.
Definition: core.h:131
A class that holds all the global variables necessary for Tiramisu.
Definition: expr.h:47
std::shared_ptr< abstract_identifier > abstract_identifier_ptr
Definition: cuda_ast.h:265
std::unordered_map< std::string, std::pair< tiramisu::primitive_t, cuda_ast::memory_location > > scalar_data_t
Definition: cuda_ast.h:514
A class that represents loop invariants.
Definition: core.h:4187
op_t
Types of tiramisu operators.
Definition: type.h:53
const op_data_t tiramisu_operation_description(tiramisu::op_t op)
expr allocate(const buffer &b)
void operator()(isl_ast_node_list *p) const
Definition: cuda_ast.h:52
std::string std_err
Definition: cuda_ast.h:707
Definition: core.h:27
std::shared_ptr< statement > statement_ptr
Definition: cuda_ast.h:191
std::unique_ptr< isl_ast_node_list, isl_ast_node_list_deleter > isl_ast_node_list_ptr
Definition: cuda_ast.h:54
std::shared_ptr< value > value_ptr
Definition: cuda_ast.h:297