xref: /llvm-project/mlir/lib/Target/Cpp/TranslateToCpp.cpp (revision 3ef90f843fee74ff811ef88246734475f50e2073)
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 #include "mlir/Dialect/EmitC/IR/EmitC.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/Support/IndentedOstream.h"
18 #include "mlir/Support/LLVM.h"
19 #include "mlir/Target/Cpp/CppEmitter.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40           typename NullaryFunctor>
41 inline LogicalResult
42 interleaveWithError(ForwardIterator begin, ForwardIterator end,
43                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44   if (begin == end)
45     return success();
46   if (failed(eachFn(*begin)))
47     return failure();
48   ++begin;
49   for (; begin != end; ++begin) {
50     betweenFn();
51     if (failed(eachFn(*begin)))
52       return failure();
53   }
54   return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59                                          UnaryFunctor eachFn,
60                                          NullaryFunctor betweenFn) {
61   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66                                               raw_ostream &os,
67                                               UnaryFunctor eachFn) {
68   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
73 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
74   return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
75       .Case<emitc::AddOp>([&](auto op) { return 12; })
76       .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77       .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78       .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79       .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80       .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81       .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82       .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83       .Case<emitc::CallOp>([&](auto op) { return 16; })
84       .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85       .Case<emitc::CastOp>([&](auto op) { return 15; })
86       .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87         switch (op.getPredicate()) {
88         case emitc::CmpPredicate::eq:
89         case emitc::CmpPredicate::ne:
90           return 8;
91         case emitc::CmpPredicate::lt:
92         case emitc::CmpPredicate::le:
93         case emitc::CmpPredicate::gt:
94         case emitc::CmpPredicate::ge:
95           return 9;
96         case emitc::CmpPredicate::three_way:
97           return 10;
98         }
99         return op->emitError("unsupported cmp predicate");
100       })
101       .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102       .Case<emitc::DivOp>([&](auto op) { return 13; })
103       .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104       .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105       .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106       .Case<emitc::MulOp>([&](auto op) { return 13; })
107       .Case<emitc::RemOp>([&](auto op) { return 13; })
108       .Case<emitc::SubOp>([&](auto op) { return 12; })
109       .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110       .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111       .Default([](auto op) { return op->emitError("unsupported operation"); });
112 }
113 
114 namespace {
115 /// Emitter that uses dialect specific emitters to emit C++ code.
116 struct CppEmitter {
117   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
118 
119   /// Emits attribute or returns failure.
120   LogicalResult emitAttribute(Location loc, Attribute attr);
121 
122   /// Emits operation 'op' with/without training semicolon or returns failure.
123   ///
124   /// For operations that should never be followed by a semicolon, like ForOp,
125   /// the `trailingSemicolon` argument is ignored and a semicolon is not
126   /// emitted.
127   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
128 
129   /// Emits type 'type' or returns failure.
130   LogicalResult emitType(Location loc, Type type);
131 
132   /// Emits array of types as a std::tuple of the emitted types.
133   /// - emits void for an empty array;
134   /// - emits the type of the only element for arrays of size one;
135   /// - emits a std::tuple otherwise;
136   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
137 
138   /// Emits array of types as a std::tuple of the emitted types independently of
139   /// the array size.
140   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
141 
142   /// Emits an assignment for a variable which has been declared previously.
143   LogicalResult emitVariableAssignment(OpResult result);
144 
145   /// Emits a variable declaration for a result of an operation.
146   LogicalResult emitVariableDeclaration(OpResult result,
147                                         bool trailingSemicolon);
148 
149   /// Emits a declaration of a variable with the given type and name.
150   LogicalResult emitVariableDeclaration(Location loc, Type type,
151                                         StringRef name);
152 
153   /// Emits the variable declaration and assignment prefix for 'op'.
154   /// - emits separate variable followed by std::tie for multi-valued operation;
155   /// - emits single type followed by variable for single result;
156   /// - emits nothing if no value produced by op;
157   /// Emits final '=' operator where a type is produced. Returns failure if
158   /// any result type could not be converted.
159   LogicalResult emitAssignPrefix(Operation &op);
160 
161   /// Emits a global variable declaration or definition.
162   LogicalResult emitGlobalVariable(GlobalOp op);
163 
164   /// Emits a label for the block.
165   LogicalResult emitLabel(Block &block);
166 
167   /// Emits the operands and atttributes of the operation. All operands are
168   /// emitted first and then all attributes in alphabetical order.
169   LogicalResult emitOperandsAndAttributes(Operation &op,
170                                           ArrayRef<StringRef> exclude = {});
171 
172   /// Emits the operands of the operation. All operands are emitted in order.
173   LogicalResult emitOperands(Operation &op);
174 
175   /// Emits value as an operands of an operation
176   LogicalResult emitOperand(Value value);
177 
178   /// Emit an expression as a C expression.
179   LogicalResult emitExpression(ExpressionOp expressionOp);
180 
181   /// Insert the expression representing the operation into the value cache.
182   void cacheDeferredOpResult(Value value, StringRef str);
183 
184   /// Return the existing or a new name for a Value.
185   StringRef getOrCreateName(Value val);
186 
187   // Returns the textual representation of a subscript operation.
188   std::string getSubscriptName(emitc::SubscriptOp op);
189 
190   // Returns the textual representation of a member (of object) operation.
191   std::string createMemberAccess(emitc::MemberOp op);
192 
193   // Returns the textual representation of a member of pointer operation.
194   std::string createMemberAccess(emitc::MemberOfPtrOp op);
195 
196   /// Return the existing or a new label of a Block.
197   StringRef getOrCreateName(Block &block);
198 
199   /// Whether to map an mlir integer to a unsigned integer in C++.
200   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
201 
202   /// RAII helper function to manage entering/exiting C++ scopes.
203   struct Scope {
204     Scope(CppEmitter &emitter)
205         : valueMapperScope(emitter.valueMapper),
206           blockMapperScope(emitter.blockMapper), emitter(emitter) {
207       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
208       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
209     }
210     ~Scope() {
211       emitter.valueInScopeCount.pop();
212       emitter.labelInScopeCount.pop();
213     }
214 
215   private:
216     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
217     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
218     CppEmitter &emitter;
219   };
220 
221   /// Returns wether the Value is assigned to a C++ variable in the scope.
222   bool hasValueInScope(Value val);
223 
224   // Returns whether a label is assigned to the block.
225   bool hasBlockLabel(Block &block);
226 
227   /// Returns the output stream.
228   raw_indented_ostream &ostream() { return os; };
229 
230   /// Returns if all variables for op results and basic block arguments need to
231   /// be declared at the beginning of a function.
232   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
233 
234   /// Get expression currently being emitted.
235   ExpressionOp getEmittedExpression() { return emittedExpression; }
236 
237   /// Determine whether given value is part of the expression potentially being
238   /// emitted.
239   bool isPartOfCurrentExpression(Value value) {
240     if (!emittedExpression)
241       return false;
242     Operation *def = value.getDefiningOp();
243     if (!def)
244       return false;
245     auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
246     return operandExpression == emittedExpression;
247   };
248 
249 private:
250   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
251   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
252 
253   /// Output stream to emit to.
254   raw_indented_ostream os;
255 
256   /// Boolean to enforce that all variables for op results and block
257   /// arguments are declared at the beginning of the function. This also
258   /// includes results from ops located in nested regions.
259   bool declareVariablesAtTop;
260 
261   /// Map from value to name of C++ variable that contain the name.
262   ValueMapper valueMapper;
263 
264   /// Map from block to name of C++ label.
265   BlockMapper blockMapper;
266 
267   /// The number of values in the current scope. This is used to declare the
268   /// names of values in a scope.
269   std::stack<int64_t> valueInScopeCount;
270   std::stack<int64_t> labelInScopeCount;
271 
272   /// State of the current expression being emitted.
273   ExpressionOp emittedExpression;
274   SmallVector<int> emittedExpressionPrecedence;
275 
276   void pushExpressionPrecedence(int precedence) {
277     emittedExpressionPrecedence.push_back(precedence);
278   }
279   void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
280   static int lowestPrecedence() { return 0; }
281   int getExpressionPrecedence() {
282     if (emittedExpressionPrecedence.empty())
283       return lowestPrecedence();
284     return emittedExpressionPrecedence.back();
285   }
286 };
287 } // namespace
288 
289 /// Determine whether expression \p op should be emitted in a deferred way.
290 static bool hasDeferredEmission(Operation *op) {
291   return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
292                          emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
293 }
294 
295 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
296 /// as part of its user. This function recommends inlining of any expressions
297 /// that can be inlined unless it is used by another expression, under the
298 /// assumption that  any expression fusion/re-materialization was taken care of
299 /// by transformations run by the backend.
300 static bool shouldBeInlined(ExpressionOp expressionOp) {
301   // Do not inline if expression is marked as such.
302   if (expressionOp.getDoNotInline())
303     return false;
304 
305   // Do not inline expressions with side effects to prevent side-effect
306   // reordering.
307   if (expressionOp.hasSideEffects())
308     return false;
309 
310   // Do not inline expressions with multiple uses.
311   Value result = expressionOp.getResult();
312   if (!result.hasOneUse())
313     return false;
314 
315   Operation *user = *result.getUsers().begin();
316 
317   // Do not inline expressions used by operations with deferred emission, since
318   // their translation requires the materialization of variables.
319   if (hasDeferredEmission(user))
320     return false;
321 
322   // Do not inline expressions used by ops with the CExpression trait. If this
323   // was intended, the user could have been merged into the expression op.
324   return !user->hasTrait<OpTrait::emitc::CExpression>();
325 }
326 
327 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
328                                      Attribute value) {
329   OpResult result = operation->getResult(0);
330 
331   // Only emit an assignment as the variable was already declared when printing
332   // the FuncOp.
333   if (emitter.shouldDeclareVariablesAtTop()) {
334     // Skip the assignment if the emitc.constant has no value.
335     if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
336       if (oAttr.getValue().empty())
337         return success();
338     }
339 
340     if (failed(emitter.emitVariableAssignment(result)))
341       return failure();
342     return emitter.emitAttribute(operation->getLoc(), value);
343   }
344 
345   // Emit a variable declaration for an emitc.constant op without value.
346   if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
347     if (oAttr.getValue().empty())
348       // The semicolon gets printed by the emitOperation function.
349       return emitter.emitVariableDeclaration(result,
350                                              /*trailingSemicolon=*/false);
351   }
352 
353   // Emit a variable declaration.
354   if (failed(emitter.emitAssignPrefix(*operation)))
355     return failure();
356   return emitter.emitAttribute(operation->getLoc(), value);
357 }
358 
359 static LogicalResult printOperation(CppEmitter &emitter,
360                                     emitc::ConstantOp constantOp) {
361   Operation *operation = constantOp.getOperation();
362   Attribute value = constantOp.getValue();
363 
364   return printConstantOp(emitter, operation, value);
365 }
366 
367 static LogicalResult printOperation(CppEmitter &emitter,
368                                     emitc::VariableOp variableOp) {
369   Operation *operation = variableOp.getOperation();
370   Attribute value = variableOp.getValue();
371 
372   return printConstantOp(emitter, operation, value);
373 }
374 
375 static LogicalResult printOperation(CppEmitter &emitter,
376                                     emitc::GlobalOp globalOp) {
377 
378   return emitter.emitGlobalVariable(globalOp);
379 }
380 
381 static LogicalResult printOperation(CppEmitter &emitter,
382                                     emitc::AssignOp assignOp) {
383   OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
384 
385   if (failed(emitter.emitVariableAssignment(result)))
386     return failure();
387 
388   return emitter.emitOperand(assignOp.getValue());
389 }
390 
391 static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) {
392   if (failed(emitter.emitAssignPrefix(*loadOp)))
393     return failure();
394 
395   return emitter.emitOperand(loadOp.getOperand());
396 }
397 
398 static LogicalResult printBinaryOperation(CppEmitter &emitter,
399                                           Operation *operation,
400                                           StringRef binaryOperator) {
401   raw_ostream &os = emitter.ostream();
402 
403   if (failed(emitter.emitAssignPrefix(*operation)))
404     return failure();
405 
406   if (failed(emitter.emitOperand(operation->getOperand(0))))
407     return failure();
408 
409   os << " " << binaryOperator << " ";
410 
411   if (failed(emitter.emitOperand(operation->getOperand(1))))
412     return failure();
413 
414   return success();
415 }
416 
417 static LogicalResult printUnaryOperation(CppEmitter &emitter,
418                                          Operation *operation,
419                                          StringRef unaryOperator) {
420   raw_ostream &os = emitter.ostream();
421 
422   if (failed(emitter.emitAssignPrefix(*operation)))
423     return failure();
424 
425   os << unaryOperator;
426 
427   if (failed(emitter.emitOperand(operation->getOperand(0))))
428     return failure();
429 
430   return success();
431 }
432 
433 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
434   Operation *operation = addOp.getOperation();
435 
436   return printBinaryOperation(emitter, operation, "+");
437 }
438 
439 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
440   Operation *operation = divOp.getOperation();
441 
442   return printBinaryOperation(emitter, operation, "/");
443 }
444 
445 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
446   Operation *operation = mulOp.getOperation();
447 
448   return printBinaryOperation(emitter, operation, "*");
449 }
450 
451 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
452   Operation *operation = remOp.getOperation();
453 
454   return printBinaryOperation(emitter, operation, "%");
455 }
456 
457 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
458   Operation *operation = subOp.getOperation();
459 
460   return printBinaryOperation(emitter, operation, "-");
461 }
462 
463 static LogicalResult emitSwitchCase(CppEmitter &emitter,
464                                     raw_indented_ostream &os, Region &region) {
465   for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end();
466        std::next(iteratorOp) != end; ++iteratorOp) {
467     if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true)))
468       return failure();
469   }
470   os << "break;\n";
471   return success();
472 }
473 
474 static LogicalResult printOperation(CppEmitter &emitter,
475                                     emitc::SwitchOp switchOp) {
476   raw_indented_ostream &os = emitter.ostream();
477 
478   os << "\nswitch (";
479   if (failed(emitter.emitOperand(switchOp.getArg())))
480     return failure();
481   os << ") {";
482 
483   for (auto pair : llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) {
484     os << "\ncase " << std::get<0>(pair) << ": {\n";
485     os.indent();
486 
487     if (failed(emitSwitchCase(emitter, os, std::get<1>(pair))))
488       return failure();
489 
490     os.unindent() << "}";
491   }
492 
493   os << "\ndefault: {\n";
494   os.indent();
495 
496   if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion())))
497     return failure();
498 
499   os.unindent() << "}\n}";
500   return success();
501 }
502 
503 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
504   Operation *operation = cmpOp.getOperation();
505 
506   StringRef binaryOperator;
507 
508   switch (cmpOp.getPredicate()) {
509   case emitc::CmpPredicate::eq:
510     binaryOperator = "==";
511     break;
512   case emitc::CmpPredicate::ne:
513     binaryOperator = "!=";
514     break;
515   case emitc::CmpPredicate::lt:
516     binaryOperator = "<";
517     break;
518   case emitc::CmpPredicate::le:
519     binaryOperator = "<=";
520     break;
521   case emitc::CmpPredicate::gt:
522     binaryOperator = ">";
523     break;
524   case emitc::CmpPredicate::ge:
525     binaryOperator = ">=";
526     break;
527   case emitc::CmpPredicate::three_way:
528     binaryOperator = "<=>";
529     break;
530   }
531 
532   return printBinaryOperation(emitter, operation, binaryOperator);
533 }
534 
535 static LogicalResult printOperation(CppEmitter &emitter,
536                                     emitc::ConditionalOp conditionalOp) {
537   raw_ostream &os = emitter.ostream();
538 
539   if (failed(emitter.emitAssignPrefix(*conditionalOp)))
540     return failure();
541 
542   if (failed(emitter.emitOperand(conditionalOp.getCondition())))
543     return failure();
544 
545   os << " ? ";
546 
547   if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
548     return failure();
549 
550   os << " : ";
551 
552   if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
553     return failure();
554 
555   return success();
556 }
557 
558 static LogicalResult printOperation(CppEmitter &emitter,
559                                     emitc::VerbatimOp verbatimOp) {
560   raw_ostream &os = emitter.ostream();
561 
562   os << verbatimOp.getValue();
563 
564   return success();
565 }
566 
567 static LogicalResult printOperation(CppEmitter &emitter,
568                                     cf::BranchOp branchOp) {
569   raw_ostream &os = emitter.ostream();
570   Block &successor = *branchOp.getSuccessor();
571 
572   for (auto pair :
573        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
574     Value &operand = std::get<0>(pair);
575     BlockArgument &argument = std::get<1>(pair);
576     os << emitter.getOrCreateName(argument) << " = "
577        << emitter.getOrCreateName(operand) << ";\n";
578   }
579 
580   os << "goto ";
581   if (!(emitter.hasBlockLabel(successor)))
582     return branchOp.emitOpError("unable to find label for successor block");
583   os << emitter.getOrCreateName(successor);
584   return success();
585 }
586 
587 static LogicalResult printOperation(CppEmitter &emitter,
588                                     cf::CondBranchOp condBranchOp) {
589   raw_indented_ostream &os = emitter.ostream();
590   Block &trueSuccessor = *condBranchOp.getTrueDest();
591   Block &falseSuccessor = *condBranchOp.getFalseDest();
592 
593   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
594      << ") {\n";
595 
596   os.indent();
597 
598   // If condition is true.
599   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
600                              trueSuccessor.getArguments())) {
601     Value &operand = std::get<0>(pair);
602     BlockArgument &argument = std::get<1>(pair);
603     os << emitter.getOrCreateName(argument) << " = "
604        << emitter.getOrCreateName(operand) << ";\n";
605   }
606 
607   os << "goto ";
608   if (!(emitter.hasBlockLabel(trueSuccessor))) {
609     return condBranchOp.emitOpError("unable to find label for successor block");
610   }
611   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
612   os.unindent() << "} else {\n";
613   os.indent();
614   // If condition is false.
615   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
616                              falseSuccessor.getArguments())) {
617     Value &operand = std::get<0>(pair);
618     BlockArgument &argument = std::get<1>(pair);
619     os << emitter.getOrCreateName(argument) << " = "
620        << emitter.getOrCreateName(operand) << ";\n";
621   }
622 
623   os << "goto ";
624   if (!(emitter.hasBlockLabel(falseSuccessor))) {
625     return condBranchOp.emitOpError()
626            << "unable to find label for successor block";
627   }
628   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
629   os.unindent() << "}";
630   return success();
631 }
632 
633 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
634                                         StringRef callee) {
635   if (failed(emitter.emitAssignPrefix(*callOp)))
636     return failure();
637 
638   raw_ostream &os = emitter.ostream();
639   os << callee << "(";
640   if (failed(emitter.emitOperands(*callOp)))
641     return failure();
642   os << ")";
643   return success();
644 }
645 
646 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
647   Operation *operation = callOp.getOperation();
648   StringRef callee = callOp.getCallee();
649 
650   return printCallOperation(emitter, operation, callee);
651 }
652 
653 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
654   Operation *operation = callOp.getOperation();
655   StringRef callee = callOp.getCallee();
656 
657   return printCallOperation(emitter, operation, callee);
658 }
659 
660 static LogicalResult printOperation(CppEmitter &emitter,
661                                     emitc::CallOpaqueOp callOpaqueOp) {
662   raw_ostream &os = emitter.ostream();
663   Operation &op = *callOpaqueOp.getOperation();
664 
665   if (failed(emitter.emitAssignPrefix(op)))
666     return failure();
667   os << callOpaqueOp.getCallee();
668 
669   auto emitArgs = [&](Attribute attr) -> LogicalResult {
670     if (auto t = dyn_cast<IntegerAttr>(attr)) {
671       // Index attributes are treated specially as operand index.
672       if (t.getType().isIndex()) {
673         int64_t idx = t.getInt();
674         Value operand = op.getOperand(idx);
675         if (!emitter.hasValueInScope(operand))
676           return op.emitOpError("operand ")
677                  << idx << "'s value not defined in scope";
678         os << emitter.getOrCreateName(operand);
679         return success();
680       }
681     }
682     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
683       return failure();
684 
685     return success();
686   };
687 
688   if (callOpaqueOp.getTemplateArgs()) {
689     os << "<";
690     if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
691                                         emitArgs)))
692       return failure();
693     os << ">";
694   }
695 
696   os << "(";
697 
698   LogicalResult emittedArgs =
699       callOpaqueOp.getArgs()
700           ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
701           : emitter.emitOperands(op);
702   if (failed(emittedArgs))
703     return failure();
704   os << ")";
705   return success();
706 }
707 
708 static LogicalResult printOperation(CppEmitter &emitter,
709                                     emitc::ApplyOp applyOp) {
710   raw_ostream &os = emitter.ostream();
711   Operation &op = *applyOp.getOperation();
712 
713   if (failed(emitter.emitAssignPrefix(op)))
714     return failure();
715   os << applyOp.getApplicableOperator();
716   os << emitter.getOrCreateName(applyOp.getOperand());
717 
718   return success();
719 }
720 
721 static LogicalResult printOperation(CppEmitter &emitter,
722                                     emitc::BitwiseAndOp bitwiseAndOp) {
723   Operation *operation = bitwiseAndOp.getOperation();
724   return printBinaryOperation(emitter, operation, "&");
725 }
726 
727 static LogicalResult
728 printOperation(CppEmitter &emitter,
729                emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
730   Operation *operation = bitwiseLeftShiftOp.getOperation();
731   return printBinaryOperation(emitter, operation, "<<");
732 }
733 
734 static LogicalResult printOperation(CppEmitter &emitter,
735                                     emitc::BitwiseNotOp bitwiseNotOp) {
736   Operation *operation = bitwiseNotOp.getOperation();
737   return printUnaryOperation(emitter, operation, "~");
738 }
739 
740 static LogicalResult printOperation(CppEmitter &emitter,
741                                     emitc::BitwiseOrOp bitwiseOrOp) {
742   Operation *operation = bitwiseOrOp.getOperation();
743   return printBinaryOperation(emitter, operation, "|");
744 }
745 
746 static LogicalResult
747 printOperation(CppEmitter &emitter,
748                emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
749   Operation *operation = bitwiseRightShiftOp.getOperation();
750   return printBinaryOperation(emitter, operation, ">>");
751 }
752 
753 static LogicalResult printOperation(CppEmitter &emitter,
754                                     emitc::BitwiseXorOp bitwiseXorOp) {
755   Operation *operation = bitwiseXorOp.getOperation();
756   return printBinaryOperation(emitter, operation, "^");
757 }
758 
759 static LogicalResult printOperation(CppEmitter &emitter,
760                                     emitc::UnaryPlusOp unaryPlusOp) {
761   Operation *operation = unaryPlusOp.getOperation();
762   return printUnaryOperation(emitter, operation, "+");
763 }
764 
765 static LogicalResult printOperation(CppEmitter &emitter,
766                                     emitc::UnaryMinusOp unaryMinusOp) {
767   Operation *operation = unaryMinusOp.getOperation();
768   return printUnaryOperation(emitter, operation, "-");
769 }
770 
771 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
772   raw_ostream &os = emitter.ostream();
773   Operation &op = *castOp.getOperation();
774 
775   if (failed(emitter.emitAssignPrefix(op)))
776     return failure();
777   os << "(";
778   if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
779     return failure();
780   os << ") ";
781   return emitter.emitOperand(castOp.getOperand());
782 }
783 
784 static LogicalResult printOperation(CppEmitter &emitter,
785                                     emitc::ExpressionOp expressionOp) {
786   if (shouldBeInlined(expressionOp))
787     return success();
788 
789   Operation &op = *expressionOp.getOperation();
790 
791   if (failed(emitter.emitAssignPrefix(op)))
792     return failure();
793 
794   return emitter.emitExpression(expressionOp);
795 }
796 
797 static LogicalResult printOperation(CppEmitter &emitter,
798                                     emitc::IncludeOp includeOp) {
799   raw_ostream &os = emitter.ostream();
800 
801   os << "#include ";
802   if (includeOp.getIsStandardInclude())
803     os << "<" << includeOp.getInclude() << ">";
804   else
805     os << "\"" << includeOp.getInclude() << "\"";
806 
807   return success();
808 }
809 
810 static LogicalResult printOperation(CppEmitter &emitter,
811                                     emitc::LogicalAndOp logicalAndOp) {
812   Operation *operation = logicalAndOp.getOperation();
813   return printBinaryOperation(emitter, operation, "&&");
814 }
815 
816 static LogicalResult printOperation(CppEmitter &emitter,
817                                     emitc::LogicalNotOp logicalNotOp) {
818   Operation *operation = logicalNotOp.getOperation();
819   return printUnaryOperation(emitter, operation, "!");
820 }
821 
822 static LogicalResult printOperation(CppEmitter &emitter,
823                                     emitc::LogicalOrOp logicalOrOp) {
824   Operation *operation = logicalOrOp.getOperation();
825   return printBinaryOperation(emitter, operation, "||");
826 }
827 
828 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
829 
830   raw_indented_ostream &os = emitter.ostream();
831 
832   // Utility function to determine whether a value is an expression that will be
833   // inlined, and as such should be wrapped in parentheses in order to guarantee
834   // its precedence and associativity.
835   auto requiresParentheses = [&](Value value) {
836     auto expressionOp =
837         dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
838     if (!expressionOp)
839       return false;
840     return shouldBeInlined(expressionOp);
841   };
842 
843   os << "for (";
844   if (failed(
845           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
846     return failure();
847   os << " ";
848   os << emitter.getOrCreateName(forOp.getInductionVar());
849   os << " = ";
850   if (failed(emitter.emitOperand(forOp.getLowerBound())))
851     return failure();
852   os << "; ";
853   os << emitter.getOrCreateName(forOp.getInductionVar());
854   os << " < ";
855   Value upperBound = forOp.getUpperBound();
856   bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
857   if (upperBoundRequiresParentheses)
858     os << "(";
859   if (failed(emitter.emitOperand(upperBound)))
860     return failure();
861   if (upperBoundRequiresParentheses)
862     os << ")";
863   os << "; ";
864   os << emitter.getOrCreateName(forOp.getInductionVar());
865   os << " += ";
866   if (failed(emitter.emitOperand(forOp.getStep())))
867     return failure();
868   os << ") {\n";
869   os.indent();
870 
871   Region &forRegion = forOp.getRegion();
872   auto regionOps = forRegion.getOps();
873 
874   // We skip the trailing yield op.
875   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
876     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
877       return failure();
878   }
879 
880   os.unindent() << "}";
881 
882   return success();
883 }
884 
885 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
886   raw_indented_ostream &os = emitter.ostream();
887 
888   // Helper function to emit all ops except the last one, expected to be
889   // emitc::yield.
890   auto emitAllExceptLast = [&emitter](Region &region) {
891     Region::OpIterator it = region.op_begin(), end = region.op_end();
892     for (; std::next(it) != end; ++it) {
893       if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
894         return failure();
895     }
896     assert(isa<emitc::YieldOp>(*it) &&
897            "Expected last operation in the region to be emitc::yield");
898     return success();
899   };
900 
901   os << "if (";
902   if (failed(emitter.emitOperand(ifOp.getCondition())))
903     return failure();
904   os << ") {\n";
905   os.indent();
906   if (failed(emitAllExceptLast(ifOp.getThenRegion())))
907     return failure();
908   os.unindent() << "}";
909 
910   Region &elseRegion = ifOp.getElseRegion();
911   if (!elseRegion.empty()) {
912     os << " else {\n";
913     os.indent();
914     if (failed(emitAllExceptLast(elseRegion)))
915       return failure();
916     os.unindent() << "}";
917   }
918 
919   return success();
920 }
921 
922 static LogicalResult printOperation(CppEmitter &emitter,
923                                     func::ReturnOp returnOp) {
924   raw_ostream &os = emitter.ostream();
925   os << "return";
926   switch (returnOp.getNumOperands()) {
927   case 0:
928     return success();
929   case 1:
930     os << " ";
931     if (failed(emitter.emitOperand(returnOp.getOperand(0))))
932       return failure();
933     return success();
934   default:
935     os << " std::make_tuple(";
936     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
937       return failure();
938     os << ")";
939     return success();
940   }
941 }
942 
943 static LogicalResult printOperation(CppEmitter &emitter,
944                                     emitc::ReturnOp returnOp) {
945   raw_ostream &os = emitter.ostream();
946   os << "return";
947   if (returnOp.getNumOperands() == 0)
948     return success();
949 
950   os << " ";
951   if (failed(emitter.emitOperand(returnOp.getOperand())))
952     return failure();
953   return success();
954 }
955 
956 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
957   CppEmitter::Scope scope(emitter);
958 
959   for (Operation &op : moduleOp) {
960     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
961       return failure();
962   }
963   return success();
964 }
965 
966 static LogicalResult printFunctionArgs(CppEmitter &emitter,
967                                        Operation *functionOp,
968                                        ArrayRef<Type> arguments) {
969   raw_indented_ostream &os = emitter.ostream();
970 
971   return (
972       interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
973         return emitter.emitType(functionOp->getLoc(), arg);
974       }));
975 }
976 
977 static LogicalResult printFunctionArgs(CppEmitter &emitter,
978                                        Operation *functionOp,
979                                        Region::BlockArgListType arguments) {
980   raw_indented_ostream &os = emitter.ostream();
981 
982   return (interleaveCommaWithError(
983       arguments, os, [&](BlockArgument arg) -> LogicalResult {
984         return emitter.emitVariableDeclaration(
985             functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
986       }));
987 }
988 
989 static LogicalResult printFunctionBody(CppEmitter &emitter,
990                                        Operation *functionOp,
991                                        Region::BlockListType &blocks) {
992   raw_indented_ostream &os = emitter.ostream();
993   os.indent();
994 
995   if (emitter.shouldDeclareVariablesAtTop()) {
996     // Declare all variables that hold op results including those from nested
997     // regions.
998     WalkResult result =
999         functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
1000           if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
1001               (isa<emitc::ExpressionOp>(op) &&
1002                shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1003             return WalkResult::skip();
1004           for (OpResult result : op->getResults()) {
1005             if (failed(emitter.emitVariableDeclaration(
1006                     result, /*trailingSemicolon=*/true))) {
1007               return WalkResult(
1008                   op->emitError("unable to declare result variable for op"));
1009             }
1010           }
1011           return WalkResult::advance();
1012         });
1013     if (result.wasInterrupted())
1014       return failure();
1015   }
1016 
1017   // Create label names for basic blocks.
1018   for (Block &block : blocks) {
1019     emitter.getOrCreateName(block);
1020   }
1021 
1022   // Declare variables for basic block arguments.
1023   for (Block &block : llvm::drop_begin(blocks)) {
1024     for (BlockArgument &arg : block.getArguments()) {
1025       if (emitter.hasValueInScope(arg))
1026         return functionOp->emitOpError(" block argument #")
1027                << arg.getArgNumber() << " is out of scope";
1028       if (isa<ArrayType, LValueType>(arg.getType()))
1029         return functionOp->emitOpError("cannot emit block argument #")
1030                << arg.getArgNumber() << " with type " << arg.getType();
1031       if (failed(
1032               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
1033         return failure();
1034       }
1035       os << " " << emitter.getOrCreateName(arg) << ";\n";
1036     }
1037   }
1038 
1039   for (Block &block : blocks) {
1040     // Only print a label if the block has predecessors.
1041     if (!block.hasNoPredecessors()) {
1042       if (failed(emitter.emitLabel(block)))
1043         return failure();
1044     }
1045     for (Operation &op : block.getOperations()) {
1046       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
1047         return failure();
1048     }
1049   }
1050 
1051   os.unindent();
1052 
1053   return success();
1054 }
1055 
1056 static LogicalResult printOperation(CppEmitter &emitter,
1057                                     func::FuncOp functionOp) {
1058   // We need to declare variables at top if the function has multiple blocks.
1059   if (!emitter.shouldDeclareVariablesAtTop() &&
1060       functionOp.getBlocks().size() > 1) {
1061     return functionOp.emitOpError(
1062         "with multiple blocks needs variables declared at top");
1063   }
1064 
1065   if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>)) {
1066     return functionOp.emitOpError()
1067            << "cannot emit lvalue type as argument type";
1068   }
1069 
1070   if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1071     return functionOp.emitOpError() << "cannot emit array type as result type";
1072   }
1073 
1074   CppEmitter::Scope scope(emitter);
1075   raw_indented_ostream &os = emitter.ostream();
1076   if (failed(emitter.emitTypes(functionOp.getLoc(),
1077                                functionOp.getFunctionType().getResults())))
1078     return failure();
1079   os << " " << functionOp.getName();
1080 
1081   os << "(";
1082   Operation *operation = functionOp.getOperation();
1083   if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1084     return failure();
1085   os << ") {\n";
1086   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1087     return failure();
1088   os << "}\n";
1089 
1090   return success();
1091 }
1092 
1093 static LogicalResult printOperation(CppEmitter &emitter,
1094                                     emitc::FuncOp functionOp) {
1095   // We need to declare variables at top if the function has multiple blocks.
1096   if (!emitter.shouldDeclareVariablesAtTop() &&
1097       functionOp.getBlocks().size() > 1) {
1098     return functionOp.emitOpError(
1099         "with multiple blocks needs variables declared at top");
1100   }
1101 
1102   CppEmitter::Scope scope(emitter);
1103   raw_indented_ostream &os = emitter.ostream();
1104   if (functionOp.getSpecifiers()) {
1105     for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1106       os << cast<StringAttr>(specifier).str() << " ";
1107     }
1108   }
1109 
1110   if (failed(emitter.emitTypes(functionOp.getLoc(),
1111                                functionOp.getFunctionType().getResults())))
1112     return failure();
1113   os << " " << functionOp.getName();
1114 
1115   os << "(";
1116   Operation *operation = functionOp.getOperation();
1117   if (functionOp.isExternal()) {
1118     if (failed(printFunctionArgs(emitter, operation,
1119                                  functionOp.getArgumentTypes())))
1120       return failure();
1121     os << ");";
1122     return success();
1123   }
1124   if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1125     return failure();
1126   os << ") {\n";
1127   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1128     return failure();
1129   os << "}\n";
1130 
1131   return success();
1132 }
1133 
1134 static LogicalResult printOperation(CppEmitter &emitter,
1135                                     DeclareFuncOp declareFuncOp) {
1136   CppEmitter::Scope scope(emitter);
1137   raw_indented_ostream &os = emitter.ostream();
1138 
1139   auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1140       declareFuncOp, declareFuncOp.getSymNameAttr());
1141 
1142   if (!functionOp)
1143     return failure();
1144 
1145   if (functionOp.getSpecifiers()) {
1146     for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1147       os << cast<StringAttr>(specifier).str() << " ";
1148     }
1149   }
1150 
1151   if (failed(emitter.emitTypes(functionOp.getLoc(),
1152                                functionOp.getFunctionType().getResults())))
1153     return failure();
1154   os << " " << functionOp.getName();
1155 
1156   os << "(";
1157   Operation *operation = functionOp.getOperation();
1158   if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1159     return failure();
1160   os << ");";
1161 
1162   return success();
1163 }
1164 
1165 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
1166     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
1167   valueInScopeCount.push(0);
1168   labelInScopeCount.push(0);
1169 }
1170 
1171 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1172   std::string out;
1173   llvm::raw_string_ostream ss(out);
1174   ss << getOrCreateName(op.getValue());
1175   for (auto index : op.getIndices()) {
1176     ss << "[" << getOrCreateName(index) << "]";
1177   }
1178   return out;
1179 }
1180 
1181 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
1182   std::string out;
1183   llvm::raw_string_ostream ss(out);
1184   ss << getOrCreateName(op.getOperand());
1185   ss << "." << op.getMember();
1186   return out;
1187 }
1188 
1189 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
1190   std::string out;
1191   llvm::raw_string_ostream ss(out);
1192   ss << getOrCreateName(op.getOperand());
1193   ss << "->" << op.getMember();
1194   return out;
1195 }
1196 
1197 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1198   if (!valueMapper.count(value))
1199     valueMapper.insert(value, str.str());
1200 }
1201 
1202 /// Return the existing or a new name for a Value.
1203 StringRef CppEmitter::getOrCreateName(Value val) {
1204   if (!valueMapper.count(val)) {
1205     assert(!hasDeferredEmission(val.getDefiningOp()) &&
1206            "cacheDeferredOpResult should have been called on this value, "
1207            "update the emitOperation function.");
1208     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1209   }
1210   return *valueMapper.begin(val);
1211 }
1212 
1213 /// Return the existing or a new label for a Block.
1214 StringRef CppEmitter::getOrCreateName(Block &block) {
1215   if (!blockMapper.count(&block))
1216     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1217   return *blockMapper.begin(&block);
1218 }
1219 
1220 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1221   switch (val) {
1222   case IntegerType::Signless:
1223     return false;
1224   case IntegerType::Signed:
1225     return false;
1226   case IntegerType::Unsigned:
1227     return true;
1228   }
1229   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1230 }
1231 
1232 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1233 
1234 bool CppEmitter::hasBlockLabel(Block &block) {
1235   return blockMapper.count(&block);
1236 }
1237 
1238 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1239   auto printInt = [&](const APInt &val, bool isUnsigned) {
1240     if (val.getBitWidth() == 1) {
1241       if (val.getBoolValue())
1242         os << "true";
1243       else
1244         os << "false";
1245     } else {
1246       SmallString<128> strValue;
1247       val.toString(strValue, 10, !isUnsigned, false);
1248       os << strValue;
1249     }
1250   };
1251 
1252   auto printFloat = [&](const APFloat &val) {
1253     if (val.isFinite()) {
1254       SmallString<128> strValue;
1255       // Use default values of toString except don't truncate zeros.
1256       val.toString(strValue, 0, 0, false);
1257       os << strValue;
1258       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1259       case llvm::APFloatBase::S_IEEEhalf:
1260         os << "f16";
1261         break;
1262       case llvm::APFloatBase::S_BFloat:
1263         os << "bf16";
1264         break;
1265       case llvm::APFloatBase::S_IEEEsingle:
1266         os << "f";
1267         break;
1268       case llvm::APFloatBase::S_IEEEdouble:
1269         break;
1270       default:
1271         llvm_unreachable("unsupported floating point type");
1272       };
1273     } else if (val.isNaN()) {
1274       os << "NAN";
1275     } else if (val.isInfinity()) {
1276       if (val.isNegative())
1277         os << "-";
1278       os << "INFINITY";
1279     }
1280   };
1281 
1282   // Print floating point attributes.
1283   if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1284     if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1285             fAttr.getType())) {
1286       return emitError(
1287           loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1288     }
1289     printFloat(fAttr.getValue());
1290     return success();
1291   }
1292   if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1293     if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1294             dense.getElementType())) {
1295       return emitError(
1296           loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1297     }
1298     os << '{';
1299     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1300     os << '}';
1301     return success();
1302   }
1303 
1304   // Print integer attributes.
1305   if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1306     if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1307       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1308       return success();
1309     }
1310     if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1311       printInt(iAttr.getValue(), false);
1312       return success();
1313     }
1314   }
1315   if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1316     if (auto iType = dyn_cast<IntegerType>(
1317             cast<TensorType>(dense.getType()).getElementType())) {
1318       os << '{';
1319       interleaveComma(dense, os, [&](const APInt &val) {
1320         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1321       });
1322       os << '}';
1323       return success();
1324     }
1325     if (auto iType = dyn_cast<IndexType>(
1326             cast<TensorType>(dense.getType()).getElementType())) {
1327       os << '{';
1328       interleaveComma(dense, os,
1329                       [&](const APInt &val) { printInt(val, false); });
1330       os << '}';
1331       return success();
1332     }
1333   }
1334 
1335   // Print opaque attributes.
1336   if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1337     os << oAttr.getValue();
1338     return success();
1339   }
1340 
1341   // Print symbolic reference attributes.
1342   if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1343     if (sAttr.getNestedReferences().size() > 1)
1344       return emitError(loc, "attribute has more than 1 nested reference");
1345     os << sAttr.getRootReference().getValue();
1346     return success();
1347   }
1348 
1349   // Print type attributes.
1350   if (auto type = dyn_cast<TypeAttr>(attr))
1351     return emitType(loc, type.getValue());
1352 
1353   return emitError(loc, "cannot emit attribute: ") << attr;
1354 }
1355 
1356 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1357   assert(emittedExpressionPrecedence.empty() &&
1358          "Expected precedence stack to be empty");
1359   Operation *rootOp = expressionOp.getRootOp();
1360 
1361   emittedExpression = expressionOp;
1362   FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1363   if (failed(precedence))
1364     return failure();
1365   pushExpressionPrecedence(precedence.value());
1366 
1367   if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1368     return failure();
1369 
1370   popExpressionPrecedence();
1371   assert(emittedExpressionPrecedence.empty() &&
1372          "Expected precedence stack to be empty");
1373   emittedExpression = nullptr;
1374 
1375   return success();
1376 }
1377 
1378 LogicalResult CppEmitter::emitOperand(Value value) {
1379   if (isPartOfCurrentExpression(value)) {
1380     Operation *def = value.getDefiningOp();
1381     assert(def && "Expected operand to be defined by an operation");
1382     FailureOr<int> precedence = getOperatorPrecedence(def);
1383     if (failed(precedence))
1384       return failure();
1385 
1386     // Sub-expressions with equal or lower precedence need to be parenthesized,
1387     // as they might be evaluated in the wrong order depending on the shape of
1388     // the expression tree.
1389     bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
1390     if (encloseInParenthesis) {
1391       os << "(";
1392       pushExpressionPrecedence(lowestPrecedence());
1393     } else
1394       pushExpressionPrecedence(precedence.value());
1395 
1396     if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1397       return failure();
1398 
1399     if (encloseInParenthesis)
1400       os << ")";
1401 
1402     popExpressionPrecedence();
1403     return success();
1404   }
1405 
1406   auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1407   if (expressionOp && shouldBeInlined(expressionOp))
1408     return emitExpression(expressionOp);
1409 
1410   os << getOrCreateName(value);
1411   return success();
1412 }
1413 
1414 LogicalResult CppEmitter::emitOperands(Operation &op) {
1415   return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1416     // If an expression is being emitted, push lowest precedence as these
1417     // operands are either wrapped by parenthesis.
1418     if (getEmittedExpression())
1419       pushExpressionPrecedence(lowestPrecedence());
1420     if (failed(emitOperand(operand)))
1421       return failure();
1422     if (getEmittedExpression())
1423       popExpressionPrecedence();
1424     return success();
1425   });
1426 }
1427 
1428 LogicalResult
1429 CppEmitter::emitOperandsAndAttributes(Operation &op,
1430                                       ArrayRef<StringRef> exclude) {
1431   if (failed(emitOperands(op)))
1432     return failure();
1433   // Insert comma in between operands and non-filtered attributes if needed.
1434   if (op.getNumOperands() > 0) {
1435     for (NamedAttribute attr : op.getAttrs()) {
1436       if (!llvm::is_contained(exclude, attr.getName().strref())) {
1437         os << ", ";
1438         break;
1439       }
1440     }
1441   }
1442   // Emit attributes.
1443   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1444     if (llvm::is_contained(exclude, attr.getName().strref()))
1445       return success();
1446     os << "/* " << attr.getName().getValue() << " */";
1447     if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1448       return failure();
1449     return success();
1450   };
1451   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1452 }
1453 
1454 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1455   if (!hasValueInScope(result)) {
1456     return result.getDefiningOp()->emitOpError(
1457         "result variable for the operation has not been declared");
1458   }
1459   os << getOrCreateName(result) << " = ";
1460   return success();
1461 }
1462 
1463 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1464                                                   bool trailingSemicolon) {
1465   if (hasDeferredEmission(result.getDefiningOp()))
1466     return success();
1467   if (hasValueInScope(result)) {
1468     return result.getDefiningOp()->emitError(
1469         "result variable for the operation already declared");
1470   }
1471   if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1472                                      result.getType(),
1473                                      getOrCreateName(result))))
1474     return failure();
1475   if (trailingSemicolon)
1476     os << ";\n";
1477   return success();
1478 }
1479 
1480 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1481   if (op.getExternSpecifier())
1482     os << "extern ";
1483   else if (op.getStaticSpecifier())
1484     os << "static ";
1485   if (op.getConstSpecifier())
1486     os << "const ";
1487 
1488   if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1489                                      op.getSymName()))) {
1490     return failure();
1491   }
1492 
1493   std::optional<Attribute> initialValue = op.getInitialValue();
1494   if (initialValue) {
1495     os << " = ";
1496     if (failed(emitAttribute(op->getLoc(), *initialValue)))
1497       return failure();
1498   }
1499 
1500   os << ";";
1501   return success();
1502 }
1503 
1504 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1505   // If op is being emitted as part of an expression, bail out.
1506   if (getEmittedExpression())
1507     return success();
1508 
1509   switch (op.getNumResults()) {
1510   case 0:
1511     break;
1512   case 1: {
1513     OpResult result = op.getResult(0);
1514     if (shouldDeclareVariablesAtTop()) {
1515       if (failed(emitVariableAssignment(result)))
1516         return failure();
1517     } else {
1518       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1519         return failure();
1520       os << " = ";
1521     }
1522     break;
1523   }
1524   default:
1525     if (!shouldDeclareVariablesAtTop()) {
1526       for (OpResult result : op.getResults()) {
1527         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1528           return failure();
1529       }
1530     }
1531     os << "std::tie(";
1532     interleaveComma(op.getResults(), os,
1533                     [&](Value result) { os << getOrCreateName(result); });
1534     os << ") = ";
1535   }
1536   return success();
1537 }
1538 
1539 LogicalResult CppEmitter::emitLabel(Block &block) {
1540   if (!hasBlockLabel(block))
1541     return block.getParentOp()->emitError("label for block not found");
1542   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1543   // label instead of using `getOStream`.
1544   os.getOStream() << getOrCreateName(block) << ":\n";
1545   return success();
1546 }
1547 
1548 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1549   LogicalResult status =
1550       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
1551           // Builtin ops.
1552           .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1553           // CF ops.
1554           .Case<cf::BranchOp, cf::CondBranchOp>(
1555               [&](auto op) { return printOperation(*this, op); })
1556           // EmitC ops.
1557           .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1558                 emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1559                 emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1560                 emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1561                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1562                 emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1563                 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1564                 emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
1565                 emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1566                 emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1567                 emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
1568                 emitc::VariableOp, emitc::VerbatimOp>(
1569               [&](auto op) { return printOperation(*this, op); })
1570           // Func ops.
1571           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1572               [&](auto op) { return printOperation(*this, op); })
1573           .Case<emitc::GetGlobalOp>([&](auto op) {
1574             cacheDeferredOpResult(op.getResult(), op.getName());
1575             return success();
1576           })
1577           .Case<emitc::LiteralOp>([&](auto op) {
1578             cacheDeferredOpResult(op.getResult(), op.getValue());
1579             return success();
1580           })
1581           .Case<emitc::MemberOp>([&](auto op) {
1582             cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1583             return success();
1584           })
1585           .Case<emitc::MemberOfPtrOp>([&](auto op) {
1586             cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1587             return success();
1588           })
1589           .Case<emitc::SubscriptOp>([&](auto op) {
1590             cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1591             return success();
1592           })
1593           .Default([&](Operation *) {
1594             return op.emitOpError("unable to find printer for op");
1595           });
1596 
1597   if (failed(status))
1598     return failure();
1599 
1600   if (hasDeferredEmission(&op))
1601     return success();
1602 
1603   if (getEmittedExpression() ||
1604       (isa<emitc::ExpressionOp>(op) &&
1605        shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1606     return success();
1607 
1608   // Never emit a semicolon for some operations, especially if endening with
1609   // `}`.
1610   trailingSemicolon &=
1611       !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp, emitc::IfOp,
1612            emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(op);
1613 
1614   os << (trailingSemicolon ? ";\n" : "\n");
1615 
1616   return success();
1617 }
1618 
1619 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1620                                                   StringRef name) {
1621   if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1622     if (failed(emitType(loc, arrType.getElementType())))
1623       return failure();
1624     os << " " << name;
1625     for (auto dim : arrType.getShape()) {
1626       os << "[" << dim << "]";
1627     }
1628     return success();
1629   }
1630   if (failed(emitType(loc, type)))
1631     return failure();
1632   os << " " << name;
1633   return success();
1634 }
1635 
1636 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1637   if (auto iType = dyn_cast<IntegerType>(type)) {
1638     switch (iType.getWidth()) {
1639     case 1:
1640       return (os << "bool"), success();
1641     case 8:
1642     case 16:
1643     case 32:
1644     case 64:
1645       if (shouldMapToUnsigned(iType.getSignedness()))
1646         return (os << "uint" << iType.getWidth() << "_t"), success();
1647       else
1648         return (os << "int" << iType.getWidth() << "_t"), success();
1649     default:
1650       return emitError(loc, "cannot emit integer type ") << type;
1651     }
1652   }
1653   if (auto fType = dyn_cast<FloatType>(type)) {
1654     switch (fType.getWidth()) {
1655     case 16: {
1656       if (llvm::isa<Float16Type>(type))
1657         return (os << "_Float16"), success();
1658       else if (llvm::isa<BFloat16Type>(type))
1659         return (os << "__bf16"), success();
1660       else
1661         return emitError(loc, "cannot emit float type ") << type;
1662     }
1663     case 32:
1664       return (os << "float"), success();
1665     case 64:
1666       return (os << "double"), success();
1667     default:
1668       return emitError(loc, "cannot emit float type ") << type;
1669     }
1670   }
1671   if (auto iType = dyn_cast<IndexType>(type))
1672     return (os << "size_t"), success();
1673   if (auto sType = dyn_cast<emitc::SizeTType>(type))
1674     return (os << "size_t"), success();
1675   if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
1676     return (os << "ssize_t"), success();
1677   if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
1678     return (os << "ptrdiff_t"), success();
1679   if (auto tType = dyn_cast<TensorType>(type)) {
1680     if (!tType.hasRank())
1681       return emitError(loc, "cannot emit unranked tensor type");
1682     if (!tType.hasStaticShape())
1683       return emitError(loc, "cannot emit tensor type with non static shape");
1684     os << "Tensor<";
1685     if (isa<ArrayType>(tType.getElementType()))
1686       return emitError(loc, "cannot emit tensor of array type ") << type;
1687     if (failed(emitType(loc, tType.getElementType())))
1688       return failure();
1689     auto shape = tType.getShape();
1690     for (auto dimSize : shape) {
1691       os << ", ";
1692       os << dimSize;
1693     }
1694     os << ">";
1695     return success();
1696   }
1697   if (auto tType = dyn_cast<TupleType>(type))
1698     return emitTupleType(loc, tType.getTypes());
1699   if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1700     os << oType.getValue();
1701     return success();
1702   }
1703   if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1704     if (failed(emitType(loc, aType.getElementType())))
1705       return failure();
1706     for (auto dim : aType.getShape())
1707       os << "[" << dim << "]";
1708     return success();
1709   }
1710   if (auto lType = dyn_cast<emitc::LValueType>(type))
1711     return emitType(loc, lType.getValueType());
1712   if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1713     if (isa<ArrayType>(pType.getPointee()))
1714       return emitError(loc, "cannot emit pointer to array type ") << type;
1715     if (failed(emitType(loc, pType.getPointee())))
1716       return failure();
1717     os << "*";
1718     return success();
1719   }
1720   return emitError(loc, "cannot emit type ") << type;
1721 }
1722 
1723 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1724   switch (types.size()) {
1725   case 0:
1726     os << "void";
1727     return success();
1728   case 1:
1729     return emitType(loc, types.front());
1730   default:
1731     return emitTupleType(loc, types);
1732   }
1733 }
1734 
1735 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1736   if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1737     return emitError(loc, "cannot emit tuple of array type");
1738   }
1739   os << "std::tuple<";
1740   if (failed(interleaveCommaWithError(
1741           types, os, [&](Type type) { return emitType(loc, type); })))
1742     return failure();
1743   os << ">";
1744   return success();
1745 }
1746 
1747 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1748                                     bool declareVariablesAtTop) {
1749   CppEmitter emitter(os, declareVariablesAtTop);
1750   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1751 }
1752