1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 10 #include "mlir/Dialect/EmitC/IR/EmitC.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "mlir/Support/IndentedOstream.h" 18 #include "mlir/Support/LLVM.h" 19 #include "mlir/Target/Cpp/CppEmitter.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/ScopedHashTable.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringMap.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include <stack> 28 #include <utility> 29 30 #define DEBUG_TYPE "translate-to-cpp" 31 32 using namespace mlir; 33 using namespace mlir::emitc; 34 using llvm::formatv; 35 36 /// Convenience functions to produce interleaved output with functions returning 37 /// a LogicalResult. This is different than those in STLExtras as functions used 38 /// on each element doesn't return a string. 39 template <typename ForwardIterator, typename UnaryFunctor, 40 typename NullaryFunctor> 41 inline LogicalResult 42 interleaveWithError(ForwardIterator begin, ForwardIterator end, 43 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 44 if (begin == end) 45 return success(); 46 if (failed(eachFn(*begin))) 47 return failure(); 48 ++begin; 49 for (; begin != end; ++begin) { 50 betweenFn(); 51 if (failed(eachFn(*begin))) 52 return failure(); 53 } 54 return success(); 55 } 56 57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 58 inline LogicalResult interleaveWithError(const Container &c, 59 UnaryFunctor eachFn, 60 NullaryFunctor betweenFn) { 61 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 62 } 63 64 template <typename Container, typename UnaryFunctor> 65 inline LogicalResult interleaveCommaWithError(const Container &c, 66 raw_ostream &os, 67 UnaryFunctor eachFn) { 68 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 69 } 70 71 /// Return the precedence of a operator as an integer, higher values 72 /// imply higher precedence. 73 static FailureOr<int> getOperatorPrecedence(Operation *operation) { 74 return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation) 75 .Case<emitc::AddOp>([&](auto op) { return 12; }) 76 .Case<emitc::ApplyOp>([&](auto op) { return 15; }) 77 .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; }) 78 .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; }) 79 .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; }) 80 .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; }) 81 .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; }) 82 .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; }) 83 .Case<emitc::CallOp>([&](auto op) { return 16; }) 84 .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; }) 85 .Case<emitc::CastOp>([&](auto op) { return 15; }) 86 .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> { 87 switch (op.getPredicate()) { 88 case emitc::CmpPredicate::eq: 89 case emitc::CmpPredicate::ne: 90 return 8; 91 case emitc::CmpPredicate::lt: 92 case emitc::CmpPredicate::le: 93 case emitc::CmpPredicate::gt: 94 case emitc::CmpPredicate::ge: 95 return 9; 96 case emitc::CmpPredicate::three_way: 97 return 10; 98 } 99 return op->emitError("unsupported cmp predicate"); 100 }) 101 .Case<emitc::ConditionalOp>([&](auto op) { return 2; }) 102 .Case<emitc::DivOp>([&](auto op) { return 13; }) 103 .Case<emitc::LogicalAndOp>([&](auto op) { return 4; }) 104 .Case<emitc::LogicalNotOp>([&](auto op) { return 15; }) 105 .Case<emitc::LogicalOrOp>([&](auto op) { return 3; }) 106 .Case<emitc::MulOp>([&](auto op) { return 13; }) 107 .Case<emitc::RemOp>([&](auto op) { return 13; }) 108 .Case<emitc::SubOp>([&](auto op) { return 12; }) 109 .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; }) 110 .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; }) 111 .Default([](auto op) { return op->emitError("unsupported operation"); }); 112 } 113 114 namespace { 115 /// Emitter that uses dialect specific emitters to emit C++ code. 116 struct CppEmitter { 117 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 118 119 /// Emits attribute or returns failure. 120 LogicalResult emitAttribute(Location loc, Attribute attr); 121 122 /// Emits operation 'op' with/without training semicolon or returns failure. 123 /// 124 /// For operations that should never be followed by a semicolon, like ForOp, 125 /// the `trailingSemicolon` argument is ignored and a semicolon is not 126 /// emitted. 127 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 128 129 /// Emits type 'type' or returns failure. 130 LogicalResult emitType(Location loc, Type type); 131 132 /// Emits array of types as a std::tuple of the emitted types. 133 /// - emits void for an empty array; 134 /// - emits the type of the only element for arrays of size one; 135 /// - emits a std::tuple otherwise; 136 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 137 138 /// Emits array of types as a std::tuple of the emitted types independently of 139 /// the array size. 140 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 141 142 /// Emits an assignment for a variable which has been declared previously. 143 LogicalResult emitVariableAssignment(OpResult result); 144 145 /// Emits a variable declaration for a result of an operation. 146 LogicalResult emitVariableDeclaration(OpResult result, 147 bool trailingSemicolon); 148 149 /// Emits a declaration of a variable with the given type and name. 150 LogicalResult emitVariableDeclaration(Location loc, Type type, 151 StringRef name); 152 153 /// Emits the variable declaration and assignment prefix for 'op'. 154 /// - emits separate variable followed by std::tie for multi-valued operation; 155 /// - emits single type followed by variable for single result; 156 /// - emits nothing if no value produced by op; 157 /// Emits final '=' operator where a type is produced. Returns failure if 158 /// any result type could not be converted. 159 LogicalResult emitAssignPrefix(Operation &op); 160 161 /// Emits a global variable declaration or definition. 162 LogicalResult emitGlobalVariable(GlobalOp op); 163 164 /// Emits a label for the block. 165 LogicalResult emitLabel(Block &block); 166 167 /// Emits the operands and atttributes of the operation. All operands are 168 /// emitted first and then all attributes in alphabetical order. 169 LogicalResult emitOperandsAndAttributes(Operation &op, 170 ArrayRef<StringRef> exclude = {}); 171 172 /// Emits the operands of the operation. All operands are emitted in order. 173 LogicalResult emitOperands(Operation &op); 174 175 /// Emits value as an operands of an operation 176 LogicalResult emitOperand(Value value); 177 178 /// Emit an expression as a C expression. 179 LogicalResult emitExpression(ExpressionOp expressionOp); 180 181 /// Insert the expression representing the operation into the value cache. 182 void cacheDeferredOpResult(Value value, StringRef str); 183 184 /// Return the existing or a new name for a Value. 185 StringRef getOrCreateName(Value val); 186 187 // Returns the textual representation of a subscript operation. 188 std::string getSubscriptName(emitc::SubscriptOp op); 189 190 // Returns the textual representation of a member (of object) operation. 191 std::string createMemberAccess(emitc::MemberOp op); 192 193 // Returns the textual representation of a member of pointer operation. 194 std::string createMemberAccess(emitc::MemberOfPtrOp op); 195 196 /// Return the existing or a new label of a Block. 197 StringRef getOrCreateName(Block &block); 198 199 /// Whether to map an mlir integer to a unsigned integer in C++. 200 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 201 202 /// RAII helper function to manage entering/exiting C++ scopes. 203 struct Scope { 204 Scope(CppEmitter &emitter) 205 : valueMapperScope(emitter.valueMapper), 206 blockMapperScope(emitter.blockMapper), emitter(emitter) { 207 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 208 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 209 } 210 ~Scope() { 211 emitter.valueInScopeCount.pop(); 212 emitter.labelInScopeCount.pop(); 213 } 214 215 private: 216 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 217 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 218 CppEmitter &emitter; 219 }; 220 221 /// Returns wether the Value is assigned to a C++ variable in the scope. 222 bool hasValueInScope(Value val); 223 224 // Returns whether a label is assigned to the block. 225 bool hasBlockLabel(Block &block); 226 227 /// Returns the output stream. 228 raw_indented_ostream &ostream() { return os; }; 229 230 /// Returns if all variables for op results and basic block arguments need to 231 /// be declared at the beginning of a function. 232 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 233 234 /// Get expression currently being emitted. 235 ExpressionOp getEmittedExpression() { return emittedExpression; } 236 237 /// Determine whether given value is part of the expression potentially being 238 /// emitted. 239 bool isPartOfCurrentExpression(Value value) { 240 if (!emittedExpression) 241 return false; 242 Operation *def = value.getDefiningOp(); 243 if (!def) 244 return false; 245 auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp()); 246 return operandExpression == emittedExpression; 247 }; 248 249 private: 250 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 251 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 252 253 /// Output stream to emit to. 254 raw_indented_ostream os; 255 256 /// Boolean to enforce that all variables for op results and block 257 /// arguments are declared at the beginning of the function. This also 258 /// includes results from ops located in nested regions. 259 bool declareVariablesAtTop; 260 261 /// Map from value to name of C++ variable that contain the name. 262 ValueMapper valueMapper; 263 264 /// Map from block to name of C++ label. 265 BlockMapper blockMapper; 266 267 /// The number of values in the current scope. This is used to declare the 268 /// names of values in a scope. 269 std::stack<int64_t> valueInScopeCount; 270 std::stack<int64_t> labelInScopeCount; 271 272 /// State of the current expression being emitted. 273 ExpressionOp emittedExpression; 274 SmallVector<int> emittedExpressionPrecedence; 275 276 void pushExpressionPrecedence(int precedence) { 277 emittedExpressionPrecedence.push_back(precedence); 278 } 279 void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); } 280 static int lowestPrecedence() { return 0; } 281 int getExpressionPrecedence() { 282 if (emittedExpressionPrecedence.empty()) 283 return lowestPrecedence(); 284 return emittedExpressionPrecedence.back(); 285 } 286 }; 287 } // namespace 288 289 /// Determine whether expression \p op should be emitted in a deferred way. 290 static bool hasDeferredEmission(Operation *op) { 291 return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp, 292 emitc::MemberOfPtrOp, emitc::SubscriptOp>(op); 293 } 294 295 /// Determine whether expression \p expressionOp should be emitted inline, i.e. 296 /// as part of its user. This function recommends inlining of any expressions 297 /// that can be inlined unless it is used by another expression, under the 298 /// assumption that any expression fusion/re-materialization was taken care of 299 /// by transformations run by the backend. 300 static bool shouldBeInlined(ExpressionOp expressionOp) { 301 // Do not inline if expression is marked as such. 302 if (expressionOp.getDoNotInline()) 303 return false; 304 305 // Do not inline expressions with side effects to prevent side-effect 306 // reordering. 307 if (expressionOp.hasSideEffects()) 308 return false; 309 310 // Do not inline expressions with multiple uses. 311 Value result = expressionOp.getResult(); 312 if (!result.hasOneUse()) 313 return false; 314 315 Operation *user = *result.getUsers().begin(); 316 317 // Do not inline expressions used by operations with deferred emission, since 318 // their translation requires the materialization of variables. 319 if (hasDeferredEmission(user)) 320 return false; 321 322 // Do not inline expressions used by ops with the CExpression trait. If this 323 // was intended, the user could have been merged into the expression op. 324 return !user->hasTrait<OpTrait::emitc::CExpression>(); 325 } 326 327 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 328 Attribute value) { 329 OpResult result = operation->getResult(0); 330 331 // Only emit an assignment as the variable was already declared when printing 332 // the FuncOp. 333 if (emitter.shouldDeclareVariablesAtTop()) { 334 // Skip the assignment if the emitc.constant has no value. 335 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) { 336 if (oAttr.getValue().empty()) 337 return success(); 338 } 339 340 if (failed(emitter.emitVariableAssignment(result))) 341 return failure(); 342 return emitter.emitAttribute(operation->getLoc(), value); 343 } 344 345 // Emit a variable declaration for an emitc.constant op without value. 346 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) { 347 if (oAttr.getValue().empty()) 348 // The semicolon gets printed by the emitOperation function. 349 return emitter.emitVariableDeclaration(result, 350 /*trailingSemicolon=*/false); 351 } 352 353 // Emit a variable declaration. 354 if (failed(emitter.emitAssignPrefix(*operation))) 355 return failure(); 356 return emitter.emitAttribute(operation->getLoc(), value); 357 } 358 359 static LogicalResult printOperation(CppEmitter &emitter, 360 emitc::ConstantOp constantOp) { 361 Operation *operation = constantOp.getOperation(); 362 Attribute value = constantOp.getValue(); 363 364 return printConstantOp(emitter, operation, value); 365 } 366 367 static LogicalResult printOperation(CppEmitter &emitter, 368 emitc::VariableOp variableOp) { 369 Operation *operation = variableOp.getOperation(); 370 Attribute value = variableOp.getValue(); 371 372 return printConstantOp(emitter, operation, value); 373 } 374 375 static LogicalResult printOperation(CppEmitter &emitter, 376 emitc::GlobalOp globalOp) { 377 378 return emitter.emitGlobalVariable(globalOp); 379 } 380 381 static LogicalResult printOperation(CppEmitter &emitter, 382 emitc::AssignOp assignOp) { 383 OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); 384 385 if (failed(emitter.emitVariableAssignment(result))) 386 return failure(); 387 388 return emitter.emitOperand(assignOp.getValue()); 389 } 390 391 static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) { 392 if (failed(emitter.emitAssignPrefix(*loadOp))) 393 return failure(); 394 395 return emitter.emitOperand(loadOp.getOperand()); 396 } 397 398 static LogicalResult printBinaryOperation(CppEmitter &emitter, 399 Operation *operation, 400 StringRef binaryOperator) { 401 raw_ostream &os = emitter.ostream(); 402 403 if (failed(emitter.emitAssignPrefix(*operation))) 404 return failure(); 405 406 if (failed(emitter.emitOperand(operation->getOperand(0)))) 407 return failure(); 408 409 os << " " << binaryOperator << " "; 410 411 if (failed(emitter.emitOperand(operation->getOperand(1)))) 412 return failure(); 413 414 return success(); 415 } 416 417 static LogicalResult printUnaryOperation(CppEmitter &emitter, 418 Operation *operation, 419 StringRef unaryOperator) { 420 raw_ostream &os = emitter.ostream(); 421 422 if (failed(emitter.emitAssignPrefix(*operation))) 423 return failure(); 424 425 os << unaryOperator; 426 427 if (failed(emitter.emitOperand(operation->getOperand(0)))) 428 return failure(); 429 430 return success(); 431 } 432 433 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) { 434 Operation *operation = addOp.getOperation(); 435 436 return printBinaryOperation(emitter, operation, "+"); 437 } 438 439 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) { 440 Operation *operation = divOp.getOperation(); 441 442 return printBinaryOperation(emitter, operation, "/"); 443 } 444 445 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) { 446 Operation *operation = mulOp.getOperation(); 447 448 return printBinaryOperation(emitter, operation, "*"); 449 } 450 451 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) { 452 Operation *operation = remOp.getOperation(); 453 454 return printBinaryOperation(emitter, operation, "%"); 455 } 456 457 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) { 458 Operation *operation = subOp.getOperation(); 459 460 return printBinaryOperation(emitter, operation, "-"); 461 } 462 463 static LogicalResult emitSwitchCase(CppEmitter &emitter, 464 raw_indented_ostream &os, Region ®ion) { 465 for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end(); 466 std::next(iteratorOp) != end; ++iteratorOp) { 467 if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true))) 468 return failure(); 469 } 470 os << "break;\n"; 471 return success(); 472 } 473 474 static LogicalResult printOperation(CppEmitter &emitter, 475 emitc::SwitchOp switchOp) { 476 raw_indented_ostream &os = emitter.ostream(); 477 478 os << "\nswitch ("; 479 if (failed(emitter.emitOperand(switchOp.getArg()))) 480 return failure(); 481 os << ") {"; 482 483 for (auto pair : llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) { 484 os << "\ncase " << std::get<0>(pair) << ": {\n"; 485 os.indent(); 486 487 if (failed(emitSwitchCase(emitter, os, std::get<1>(pair)))) 488 return failure(); 489 490 os.unindent() << "}"; 491 } 492 493 os << "\ndefault: {\n"; 494 os.indent(); 495 496 if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion()))) 497 return failure(); 498 499 os.unindent() << "}\n}"; 500 return success(); 501 } 502 503 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { 504 Operation *operation = cmpOp.getOperation(); 505 506 StringRef binaryOperator; 507 508 switch (cmpOp.getPredicate()) { 509 case emitc::CmpPredicate::eq: 510 binaryOperator = "=="; 511 break; 512 case emitc::CmpPredicate::ne: 513 binaryOperator = "!="; 514 break; 515 case emitc::CmpPredicate::lt: 516 binaryOperator = "<"; 517 break; 518 case emitc::CmpPredicate::le: 519 binaryOperator = "<="; 520 break; 521 case emitc::CmpPredicate::gt: 522 binaryOperator = ">"; 523 break; 524 case emitc::CmpPredicate::ge: 525 binaryOperator = ">="; 526 break; 527 case emitc::CmpPredicate::three_way: 528 binaryOperator = "<=>"; 529 break; 530 } 531 532 return printBinaryOperation(emitter, operation, binaryOperator); 533 } 534 535 static LogicalResult printOperation(CppEmitter &emitter, 536 emitc::ConditionalOp conditionalOp) { 537 raw_ostream &os = emitter.ostream(); 538 539 if (failed(emitter.emitAssignPrefix(*conditionalOp))) 540 return failure(); 541 542 if (failed(emitter.emitOperand(conditionalOp.getCondition()))) 543 return failure(); 544 545 os << " ? "; 546 547 if (failed(emitter.emitOperand(conditionalOp.getTrueValue()))) 548 return failure(); 549 550 os << " : "; 551 552 if (failed(emitter.emitOperand(conditionalOp.getFalseValue()))) 553 return failure(); 554 555 return success(); 556 } 557 558 static LogicalResult printOperation(CppEmitter &emitter, 559 emitc::VerbatimOp verbatimOp) { 560 raw_ostream &os = emitter.ostream(); 561 562 os << verbatimOp.getValue(); 563 564 return success(); 565 } 566 567 static LogicalResult printOperation(CppEmitter &emitter, 568 cf::BranchOp branchOp) { 569 raw_ostream &os = emitter.ostream(); 570 Block &successor = *branchOp.getSuccessor(); 571 572 for (auto pair : 573 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 574 Value &operand = std::get<0>(pair); 575 BlockArgument &argument = std::get<1>(pair); 576 os << emitter.getOrCreateName(argument) << " = " 577 << emitter.getOrCreateName(operand) << ";\n"; 578 } 579 580 os << "goto "; 581 if (!(emitter.hasBlockLabel(successor))) 582 return branchOp.emitOpError("unable to find label for successor block"); 583 os << emitter.getOrCreateName(successor); 584 return success(); 585 } 586 587 static LogicalResult printOperation(CppEmitter &emitter, 588 cf::CondBranchOp condBranchOp) { 589 raw_indented_ostream &os = emitter.ostream(); 590 Block &trueSuccessor = *condBranchOp.getTrueDest(); 591 Block &falseSuccessor = *condBranchOp.getFalseDest(); 592 593 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 594 << ") {\n"; 595 596 os.indent(); 597 598 // If condition is true. 599 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 600 trueSuccessor.getArguments())) { 601 Value &operand = std::get<0>(pair); 602 BlockArgument &argument = std::get<1>(pair); 603 os << emitter.getOrCreateName(argument) << " = " 604 << emitter.getOrCreateName(operand) << ";\n"; 605 } 606 607 os << "goto "; 608 if (!(emitter.hasBlockLabel(trueSuccessor))) { 609 return condBranchOp.emitOpError("unable to find label for successor block"); 610 } 611 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 612 os.unindent() << "} else {\n"; 613 os.indent(); 614 // If condition is false. 615 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 616 falseSuccessor.getArguments())) { 617 Value &operand = std::get<0>(pair); 618 BlockArgument &argument = std::get<1>(pair); 619 os << emitter.getOrCreateName(argument) << " = " 620 << emitter.getOrCreateName(operand) << ";\n"; 621 } 622 623 os << "goto "; 624 if (!(emitter.hasBlockLabel(falseSuccessor))) { 625 return condBranchOp.emitOpError() 626 << "unable to find label for successor block"; 627 } 628 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 629 os.unindent() << "}"; 630 return success(); 631 } 632 633 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, 634 StringRef callee) { 635 if (failed(emitter.emitAssignPrefix(*callOp))) 636 return failure(); 637 638 raw_ostream &os = emitter.ostream(); 639 os << callee << "("; 640 if (failed(emitter.emitOperands(*callOp))) 641 return failure(); 642 os << ")"; 643 return success(); 644 } 645 646 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { 647 Operation *operation = callOp.getOperation(); 648 StringRef callee = callOp.getCallee(); 649 650 return printCallOperation(emitter, operation, callee); 651 } 652 653 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 654 Operation *operation = callOp.getOperation(); 655 StringRef callee = callOp.getCallee(); 656 657 return printCallOperation(emitter, operation, callee); 658 } 659 660 static LogicalResult printOperation(CppEmitter &emitter, 661 emitc::CallOpaqueOp callOpaqueOp) { 662 raw_ostream &os = emitter.ostream(); 663 Operation &op = *callOpaqueOp.getOperation(); 664 665 if (failed(emitter.emitAssignPrefix(op))) 666 return failure(); 667 os << callOpaqueOp.getCallee(); 668 669 auto emitArgs = [&](Attribute attr) -> LogicalResult { 670 if (auto t = dyn_cast<IntegerAttr>(attr)) { 671 // Index attributes are treated specially as operand index. 672 if (t.getType().isIndex()) { 673 int64_t idx = t.getInt(); 674 Value operand = op.getOperand(idx); 675 if (!emitter.hasValueInScope(operand)) 676 return op.emitOpError("operand ") 677 << idx << "'s value not defined in scope"; 678 os << emitter.getOrCreateName(operand); 679 return success(); 680 } 681 } 682 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 683 return failure(); 684 685 return success(); 686 }; 687 688 if (callOpaqueOp.getTemplateArgs()) { 689 os << "<"; 690 if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, 691 emitArgs))) 692 return failure(); 693 os << ">"; 694 } 695 696 os << "("; 697 698 LogicalResult emittedArgs = 699 callOpaqueOp.getArgs() 700 ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs) 701 : emitter.emitOperands(op); 702 if (failed(emittedArgs)) 703 return failure(); 704 os << ")"; 705 return success(); 706 } 707 708 static LogicalResult printOperation(CppEmitter &emitter, 709 emitc::ApplyOp applyOp) { 710 raw_ostream &os = emitter.ostream(); 711 Operation &op = *applyOp.getOperation(); 712 713 if (failed(emitter.emitAssignPrefix(op))) 714 return failure(); 715 os << applyOp.getApplicableOperator(); 716 os << emitter.getOrCreateName(applyOp.getOperand()); 717 718 return success(); 719 } 720 721 static LogicalResult printOperation(CppEmitter &emitter, 722 emitc::BitwiseAndOp bitwiseAndOp) { 723 Operation *operation = bitwiseAndOp.getOperation(); 724 return printBinaryOperation(emitter, operation, "&"); 725 } 726 727 static LogicalResult 728 printOperation(CppEmitter &emitter, 729 emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) { 730 Operation *operation = bitwiseLeftShiftOp.getOperation(); 731 return printBinaryOperation(emitter, operation, "<<"); 732 } 733 734 static LogicalResult printOperation(CppEmitter &emitter, 735 emitc::BitwiseNotOp bitwiseNotOp) { 736 Operation *operation = bitwiseNotOp.getOperation(); 737 return printUnaryOperation(emitter, operation, "~"); 738 } 739 740 static LogicalResult printOperation(CppEmitter &emitter, 741 emitc::BitwiseOrOp bitwiseOrOp) { 742 Operation *operation = bitwiseOrOp.getOperation(); 743 return printBinaryOperation(emitter, operation, "|"); 744 } 745 746 static LogicalResult 747 printOperation(CppEmitter &emitter, 748 emitc::BitwiseRightShiftOp bitwiseRightShiftOp) { 749 Operation *operation = bitwiseRightShiftOp.getOperation(); 750 return printBinaryOperation(emitter, operation, ">>"); 751 } 752 753 static LogicalResult printOperation(CppEmitter &emitter, 754 emitc::BitwiseXorOp bitwiseXorOp) { 755 Operation *operation = bitwiseXorOp.getOperation(); 756 return printBinaryOperation(emitter, operation, "^"); 757 } 758 759 static LogicalResult printOperation(CppEmitter &emitter, 760 emitc::UnaryPlusOp unaryPlusOp) { 761 Operation *operation = unaryPlusOp.getOperation(); 762 return printUnaryOperation(emitter, operation, "+"); 763 } 764 765 static LogicalResult printOperation(CppEmitter &emitter, 766 emitc::UnaryMinusOp unaryMinusOp) { 767 Operation *operation = unaryMinusOp.getOperation(); 768 return printUnaryOperation(emitter, operation, "-"); 769 } 770 771 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { 772 raw_ostream &os = emitter.ostream(); 773 Operation &op = *castOp.getOperation(); 774 775 if (failed(emitter.emitAssignPrefix(op))) 776 return failure(); 777 os << "("; 778 if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) 779 return failure(); 780 os << ") "; 781 return emitter.emitOperand(castOp.getOperand()); 782 } 783 784 static LogicalResult printOperation(CppEmitter &emitter, 785 emitc::ExpressionOp expressionOp) { 786 if (shouldBeInlined(expressionOp)) 787 return success(); 788 789 Operation &op = *expressionOp.getOperation(); 790 791 if (failed(emitter.emitAssignPrefix(op))) 792 return failure(); 793 794 return emitter.emitExpression(expressionOp); 795 } 796 797 static LogicalResult printOperation(CppEmitter &emitter, 798 emitc::IncludeOp includeOp) { 799 raw_ostream &os = emitter.ostream(); 800 801 os << "#include "; 802 if (includeOp.getIsStandardInclude()) 803 os << "<" << includeOp.getInclude() << ">"; 804 else 805 os << "\"" << includeOp.getInclude() << "\""; 806 807 return success(); 808 } 809 810 static LogicalResult printOperation(CppEmitter &emitter, 811 emitc::LogicalAndOp logicalAndOp) { 812 Operation *operation = logicalAndOp.getOperation(); 813 return printBinaryOperation(emitter, operation, "&&"); 814 } 815 816 static LogicalResult printOperation(CppEmitter &emitter, 817 emitc::LogicalNotOp logicalNotOp) { 818 Operation *operation = logicalNotOp.getOperation(); 819 return printUnaryOperation(emitter, operation, "!"); 820 } 821 822 static LogicalResult printOperation(CppEmitter &emitter, 823 emitc::LogicalOrOp logicalOrOp) { 824 Operation *operation = logicalOrOp.getOperation(); 825 return printBinaryOperation(emitter, operation, "||"); 826 } 827 828 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { 829 830 raw_indented_ostream &os = emitter.ostream(); 831 832 // Utility function to determine whether a value is an expression that will be 833 // inlined, and as such should be wrapped in parentheses in order to guarantee 834 // its precedence and associativity. 835 auto requiresParentheses = [&](Value value) { 836 auto expressionOp = 837 dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()); 838 if (!expressionOp) 839 return false; 840 return shouldBeInlined(expressionOp); 841 }; 842 843 os << "for ("; 844 if (failed( 845 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 846 return failure(); 847 os << " "; 848 os << emitter.getOrCreateName(forOp.getInductionVar()); 849 os << " = "; 850 if (failed(emitter.emitOperand(forOp.getLowerBound()))) 851 return failure(); 852 os << "; "; 853 os << emitter.getOrCreateName(forOp.getInductionVar()); 854 os << " < "; 855 Value upperBound = forOp.getUpperBound(); 856 bool upperBoundRequiresParentheses = requiresParentheses(upperBound); 857 if (upperBoundRequiresParentheses) 858 os << "("; 859 if (failed(emitter.emitOperand(upperBound))) 860 return failure(); 861 if (upperBoundRequiresParentheses) 862 os << ")"; 863 os << "; "; 864 os << emitter.getOrCreateName(forOp.getInductionVar()); 865 os << " += "; 866 if (failed(emitter.emitOperand(forOp.getStep()))) 867 return failure(); 868 os << ") {\n"; 869 os.indent(); 870 871 Region &forRegion = forOp.getRegion(); 872 auto regionOps = forRegion.getOps(); 873 874 // We skip the trailing yield op. 875 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 876 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 877 return failure(); 878 } 879 880 os.unindent() << "}"; 881 882 return success(); 883 } 884 885 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) { 886 raw_indented_ostream &os = emitter.ostream(); 887 888 // Helper function to emit all ops except the last one, expected to be 889 // emitc::yield. 890 auto emitAllExceptLast = [&emitter](Region ®ion) { 891 Region::OpIterator it = region.op_begin(), end = region.op_end(); 892 for (; std::next(it) != end; ++it) { 893 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 894 return failure(); 895 } 896 assert(isa<emitc::YieldOp>(*it) && 897 "Expected last operation in the region to be emitc::yield"); 898 return success(); 899 }; 900 901 os << "if ("; 902 if (failed(emitter.emitOperand(ifOp.getCondition()))) 903 return failure(); 904 os << ") {\n"; 905 os.indent(); 906 if (failed(emitAllExceptLast(ifOp.getThenRegion()))) 907 return failure(); 908 os.unindent() << "}"; 909 910 Region &elseRegion = ifOp.getElseRegion(); 911 if (!elseRegion.empty()) { 912 os << " else {\n"; 913 os.indent(); 914 if (failed(emitAllExceptLast(elseRegion))) 915 return failure(); 916 os.unindent() << "}"; 917 } 918 919 return success(); 920 } 921 922 static LogicalResult printOperation(CppEmitter &emitter, 923 func::ReturnOp returnOp) { 924 raw_ostream &os = emitter.ostream(); 925 os << "return"; 926 switch (returnOp.getNumOperands()) { 927 case 0: 928 return success(); 929 case 1: 930 os << " "; 931 if (failed(emitter.emitOperand(returnOp.getOperand(0)))) 932 return failure(); 933 return success(); 934 default: 935 os << " std::make_tuple("; 936 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 937 return failure(); 938 os << ")"; 939 return success(); 940 } 941 } 942 943 static LogicalResult printOperation(CppEmitter &emitter, 944 emitc::ReturnOp returnOp) { 945 raw_ostream &os = emitter.ostream(); 946 os << "return"; 947 if (returnOp.getNumOperands() == 0) 948 return success(); 949 950 os << " "; 951 if (failed(emitter.emitOperand(returnOp.getOperand()))) 952 return failure(); 953 return success(); 954 } 955 956 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 957 CppEmitter::Scope scope(emitter); 958 959 for (Operation &op : moduleOp) { 960 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 961 return failure(); 962 } 963 return success(); 964 } 965 966 static LogicalResult printFunctionArgs(CppEmitter &emitter, 967 Operation *functionOp, 968 ArrayRef<Type> arguments) { 969 raw_indented_ostream &os = emitter.ostream(); 970 971 return ( 972 interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult { 973 return emitter.emitType(functionOp->getLoc(), arg); 974 })); 975 } 976 977 static LogicalResult printFunctionArgs(CppEmitter &emitter, 978 Operation *functionOp, 979 Region::BlockArgListType arguments) { 980 raw_indented_ostream &os = emitter.ostream(); 981 982 return (interleaveCommaWithError( 983 arguments, os, [&](BlockArgument arg) -> LogicalResult { 984 return emitter.emitVariableDeclaration( 985 functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg)); 986 })); 987 } 988 989 static LogicalResult printFunctionBody(CppEmitter &emitter, 990 Operation *functionOp, 991 Region::BlockListType &blocks) { 992 raw_indented_ostream &os = emitter.ostream(); 993 os.indent(); 994 995 if (emitter.shouldDeclareVariablesAtTop()) { 996 // Declare all variables that hold op results including those from nested 997 // regions. 998 WalkResult result = 999 functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 1000 if (isa<emitc::ExpressionOp>(op->getParentOp()) || 1001 (isa<emitc::ExpressionOp>(op) && 1002 shouldBeInlined(cast<emitc::ExpressionOp>(op)))) 1003 return WalkResult::skip(); 1004 for (OpResult result : op->getResults()) { 1005 if (failed(emitter.emitVariableDeclaration( 1006 result, /*trailingSemicolon=*/true))) { 1007 return WalkResult( 1008 op->emitError("unable to declare result variable for op")); 1009 } 1010 } 1011 return WalkResult::advance(); 1012 }); 1013 if (result.wasInterrupted()) 1014 return failure(); 1015 } 1016 1017 // Create label names for basic blocks. 1018 for (Block &block : blocks) { 1019 emitter.getOrCreateName(block); 1020 } 1021 1022 // Declare variables for basic block arguments. 1023 for (Block &block : llvm::drop_begin(blocks)) { 1024 for (BlockArgument &arg : block.getArguments()) { 1025 if (emitter.hasValueInScope(arg)) 1026 return functionOp->emitOpError(" block argument #") 1027 << arg.getArgNumber() << " is out of scope"; 1028 if (isa<ArrayType, LValueType>(arg.getType())) 1029 return functionOp->emitOpError("cannot emit block argument #") 1030 << arg.getArgNumber() << " with type " << arg.getType(); 1031 if (failed( 1032 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 1033 return failure(); 1034 } 1035 os << " " << emitter.getOrCreateName(arg) << ";\n"; 1036 } 1037 } 1038 1039 for (Block &block : blocks) { 1040 // Only print a label if the block has predecessors. 1041 if (!block.hasNoPredecessors()) { 1042 if (failed(emitter.emitLabel(block))) 1043 return failure(); 1044 } 1045 for (Operation &op : block.getOperations()) { 1046 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 1047 return failure(); 1048 } 1049 } 1050 1051 os.unindent(); 1052 1053 return success(); 1054 } 1055 1056 static LogicalResult printOperation(CppEmitter &emitter, 1057 func::FuncOp functionOp) { 1058 // We need to declare variables at top if the function has multiple blocks. 1059 if (!emitter.shouldDeclareVariablesAtTop() && 1060 functionOp.getBlocks().size() > 1) { 1061 return functionOp.emitOpError( 1062 "with multiple blocks needs variables declared at top"); 1063 } 1064 1065 if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>)) { 1066 return functionOp.emitOpError() 1067 << "cannot emit lvalue type as argument type"; 1068 } 1069 1070 if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) { 1071 return functionOp.emitOpError() << "cannot emit array type as result type"; 1072 } 1073 1074 CppEmitter::Scope scope(emitter); 1075 raw_indented_ostream &os = emitter.ostream(); 1076 if (failed(emitter.emitTypes(functionOp.getLoc(), 1077 functionOp.getFunctionType().getResults()))) 1078 return failure(); 1079 os << " " << functionOp.getName(); 1080 1081 os << "("; 1082 Operation *operation = functionOp.getOperation(); 1083 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) 1084 return failure(); 1085 os << ") {\n"; 1086 if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) 1087 return failure(); 1088 os << "}\n"; 1089 1090 return success(); 1091 } 1092 1093 static LogicalResult printOperation(CppEmitter &emitter, 1094 emitc::FuncOp functionOp) { 1095 // We need to declare variables at top if the function has multiple blocks. 1096 if (!emitter.shouldDeclareVariablesAtTop() && 1097 functionOp.getBlocks().size() > 1) { 1098 return functionOp.emitOpError( 1099 "with multiple blocks needs variables declared at top"); 1100 } 1101 1102 CppEmitter::Scope scope(emitter); 1103 raw_indented_ostream &os = emitter.ostream(); 1104 if (functionOp.getSpecifiers()) { 1105 for (Attribute specifier : functionOp.getSpecifiersAttr()) { 1106 os << cast<StringAttr>(specifier).str() << " "; 1107 } 1108 } 1109 1110 if (failed(emitter.emitTypes(functionOp.getLoc(), 1111 functionOp.getFunctionType().getResults()))) 1112 return failure(); 1113 os << " " << functionOp.getName(); 1114 1115 os << "("; 1116 Operation *operation = functionOp.getOperation(); 1117 if (functionOp.isExternal()) { 1118 if (failed(printFunctionArgs(emitter, operation, 1119 functionOp.getArgumentTypes()))) 1120 return failure(); 1121 os << ");"; 1122 return success(); 1123 } 1124 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) 1125 return failure(); 1126 os << ") {\n"; 1127 if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) 1128 return failure(); 1129 os << "}\n"; 1130 1131 return success(); 1132 } 1133 1134 static LogicalResult printOperation(CppEmitter &emitter, 1135 DeclareFuncOp declareFuncOp) { 1136 CppEmitter::Scope scope(emitter); 1137 raw_indented_ostream &os = emitter.ostream(); 1138 1139 auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>( 1140 declareFuncOp, declareFuncOp.getSymNameAttr()); 1141 1142 if (!functionOp) 1143 return failure(); 1144 1145 if (functionOp.getSpecifiers()) { 1146 for (Attribute specifier : functionOp.getSpecifiersAttr()) { 1147 os << cast<StringAttr>(specifier).str() << " "; 1148 } 1149 } 1150 1151 if (failed(emitter.emitTypes(functionOp.getLoc(), 1152 functionOp.getFunctionType().getResults()))) 1153 return failure(); 1154 os << " " << functionOp.getName(); 1155 1156 os << "("; 1157 Operation *operation = functionOp.getOperation(); 1158 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) 1159 return failure(); 1160 os << ");"; 1161 1162 return success(); 1163 } 1164 1165 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 1166 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 1167 valueInScopeCount.push(0); 1168 labelInScopeCount.push(0); 1169 } 1170 1171 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { 1172 std::string out; 1173 llvm::raw_string_ostream ss(out); 1174 ss << getOrCreateName(op.getValue()); 1175 for (auto index : op.getIndices()) { 1176 ss << "[" << getOrCreateName(index) << "]"; 1177 } 1178 return out; 1179 } 1180 1181 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) { 1182 std::string out; 1183 llvm::raw_string_ostream ss(out); 1184 ss << getOrCreateName(op.getOperand()); 1185 ss << "." << op.getMember(); 1186 return out; 1187 } 1188 1189 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) { 1190 std::string out; 1191 llvm::raw_string_ostream ss(out); 1192 ss << getOrCreateName(op.getOperand()); 1193 ss << "->" << op.getMember(); 1194 return out; 1195 } 1196 1197 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) { 1198 if (!valueMapper.count(value)) 1199 valueMapper.insert(value, str.str()); 1200 } 1201 1202 /// Return the existing or a new name for a Value. 1203 StringRef CppEmitter::getOrCreateName(Value val) { 1204 if (!valueMapper.count(val)) { 1205 assert(!hasDeferredEmission(val.getDefiningOp()) && 1206 "cacheDeferredOpResult should have been called on this value, " 1207 "update the emitOperation function."); 1208 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 1209 } 1210 return *valueMapper.begin(val); 1211 } 1212 1213 /// Return the existing or a new label for a Block. 1214 StringRef CppEmitter::getOrCreateName(Block &block) { 1215 if (!blockMapper.count(&block)) 1216 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 1217 return *blockMapper.begin(&block); 1218 } 1219 1220 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 1221 switch (val) { 1222 case IntegerType::Signless: 1223 return false; 1224 case IntegerType::Signed: 1225 return false; 1226 case IntegerType::Unsigned: 1227 return true; 1228 } 1229 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 1230 } 1231 1232 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 1233 1234 bool CppEmitter::hasBlockLabel(Block &block) { 1235 return blockMapper.count(&block); 1236 } 1237 1238 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 1239 auto printInt = [&](const APInt &val, bool isUnsigned) { 1240 if (val.getBitWidth() == 1) { 1241 if (val.getBoolValue()) 1242 os << "true"; 1243 else 1244 os << "false"; 1245 } else { 1246 SmallString<128> strValue; 1247 val.toString(strValue, 10, !isUnsigned, false); 1248 os << strValue; 1249 } 1250 }; 1251 1252 auto printFloat = [&](const APFloat &val) { 1253 if (val.isFinite()) { 1254 SmallString<128> strValue; 1255 // Use default values of toString except don't truncate zeros. 1256 val.toString(strValue, 0, 0, false); 1257 os << strValue; 1258 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 1259 case llvm::APFloatBase::S_IEEEhalf: 1260 os << "f16"; 1261 break; 1262 case llvm::APFloatBase::S_BFloat: 1263 os << "bf16"; 1264 break; 1265 case llvm::APFloatBase::S_IEEEsingle: 1266 os << "f"; 1267 break; 1268 case llvm::APFloatBase::S_IEEEdouble: 1269 break; 1270 default: 1271 llvm_unreachable("unsupported floating point type"); 1272 }; 1273 } else if (val.isNaN()) { 1274 os << "NAN"; 1275 } else if (val.isInfinity()) { 1276 if (val.isNegative()) 1277 os << "-"; 1278 os << "INFINITY"; 1279 } 1280 }; 1281 1282 // Print floating point attributes. 1283 if (auto fAttr = dyn_cast<FloatAttr>(attr)) { 1284 if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>( 1285 fAttr.getType())) { 1286 return emitError( 1287 loc, "expected floating point attribute to be f16, bf16, f32 or f64"); 1288 } 1289 printFloat(fAttr.getValue()); 1290 return success(); 1291 } 1292 if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) { 1293 if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>( 1294 dense.getElementType())) { 1295 return emitError( 1296 loc, "expected floating point attribute to be f16, bf16, f32 or f64"); 1297 } 1298 os << '{'; 1299 interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); 1300 os << '}'; 1301 return success(); 1302 } 1303 1304 // Print integer attributes. 1305 if (auto iAttr = dyn_cast<IntegerAttr>(attr)) { 1306 if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) { 1307 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 1308 return success(); 1309 } 1310 if (auto iType = dyn_cast<IndexType>(iAttr.getType())) { 1311 printInt(iAttr.getValue(), false); 1312 return success(); 1313 } 1314 } 1315 if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) { 1316 if (auto iType = dyn_cast<IntegerType>( 1317 cast<TensorType>(dense.getType()).getElementType())) { 1318 os << '{'; 1319 interleaveComma(dense, os, [&](const APInt &val) { 1320 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 1321 }); 1322 os << '}'; 1323 return success(); 1324 } 1325 if (auto iType = dyn_cast<IndexType>( 1326 cast<TensorType>(dense.getType()).getElementType())) { 1327 os << '{'; 1328 interleaveComma(dense, os, 1329 [&](const APInt &val) { printInt(val, false); }); 1330 os << '}'; 1331 return success(); 1332 } 1333 } 1334 1335 // Print opaque attributes. 1336 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) { 1337 os << oAttr.getValue(); 1338 return success(); 1339 } 1340 1341 // Print symbolic reference attributes. 1342 if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) { 1343 if (sAttr.getNestedReferences().size() > 1) 1344 return emitError(loc, "attribute has more than 1 nested reference"); 1345 os << sAttr.getRootReference().getValue(); 1346 return success(); 1347 } 1348 1349 // Print type attributes. 1350 if (auto type = dyn_cast<TypeAttr>(attr)) 1351 return emitType(loc, type.getValue()); 1352 1353 return emitError(loc, "cannot emit attribute: ") << attr; 1354 } 1355 1356 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { 1357 assert(emittedExpressionPrecedence.empty() && 1358 "Expected precedence stack to be empty"); 1359 Operation *rootOp = expressionOp.getRootOp(); 1360 1361 emittedExpression = expressionOp; 1362 FailureOr<int> precedence = getOperatorPrecedence(rootOp); 1363 if (failed(precedence)) 1364 return failure(); 1365 pushExpressionPrecedence(precedence.value()); 1366 1367 if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false))) 1368 return failure(); 1369 1370 popExpressionPrecedence(); 1371 assert(emittedExpressionPrecedence.empty() && 1372 "Expected precedence stack to be empty"); 1373 emittedExpression = nullptr; 1374 1375 return success(); 1376 } 1377 1378 LogicalResult CppEmitter::emitOperand(Value value) { 1379 if (isPartOfCurrentExpression(value)) { 1380 Operation *def = value.getDefiningOp(); 1381 assert(def && "Expected operand to be defined by an operation"); 1382 FailureOr<int> precedence = getOperatorPrecedence(def); 1383 if (failed(precedence)) 1384 return failure(); 1385 1386 // Sub-expressions with equal or lower precedence need to be parenthesized, 1387 // as they might be evaluated in the wrong order depending on the shape of 1388 // the expression tree. 1389 bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence(); 1390 if (encloseInParenthesis) { 1391 os << "("; 1392 pushExpressionPrecedence(lowestPrecedence()); 1393 } else 1394 pushExpressionPrecedence(precedence.value()); 1395 1396 if (failed(emitOperation(*def, /*trailingSemicolon=*/false))) 1397 return failure(); 1398 1399 if (encloseInParenthesis) 1400 os << ")"; 1401 1402 popExpressionPrecedence(); 1403 return success(); 1404 } 1405 1406 auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()); 1407 if (expressionOp && shouldBeInlined(expressionOp)) 1408 return emitExpression(expressionOp); 1409 1410 os << getOrCreateName(value); 1411 return success(); 1412 } 1413 1414 LogicalResult CppEmitter::emitOperands(Operation &op) { 1415 return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { 1416 // If an expression is being emitted, push lowest precedence as these 1417 // operands are either wrapped by parenthesis. 1418 if (getEmittedExpression()) 1419 pushExpressionPrecedence(lowestPrecedence()); 1420 if (failed(emitOperand(operand))) 1421 return failure(); 1422 if (getEmittedExpression()) 1423 popExpressionPrecedence(); 1424 return success(); 1425 }); 1426 } 1427 1428 LogicalResult 1429 CppEmitter::emitOperandsAndAttributes(Operation &op, 1430 ArrayRef<StringRef> exclude) { 1431 if (failed(emitOperands(op))) 1432 return failure(); 1433 // Insert comma in between operands and non-filtered attributes if needed. 1434 if (op.getNumOperands() > 0) { 1435 for (NamedAttribute attr : op.getAttrs()) { 1436 if (!llvm::is_contained(exclude, attr.getName().strref())) { 1437 os << ", "; 1438 break; 1439 } 1440 } 1441 } 1442 // Emit attributes. 1443 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 1444 if (llvm::is_contained(exclude, attr.getName().strref())) 1445 return success(); 1446 os << "/* " << attr.getName().getValue() << " */"; 1447 if (failed(emitAttribute(op.getLoc(), attr.getValue()))) 1448 return failure(); 1449 return success(); 1450 }; 1451 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 1452 } 1453 1454 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 1455 if (!hasValueInScope(result)) { 1456 return result.getDefiningOp()->emitOpError( 1457 "result variable for the operation has not been declared"); 1458 } 1459 os << getOrCreateName(result) << " = "; 1460 return success(); 1461 } 1462 1463 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 1464 bool trailingSemicolon) { 1465 if (hasDeferredEmission(result.getDefiningOp())) 1466 return success(); 1467 if (hasValueInScope(result)) { 1468 return result.getDefiningOp()->emitError( 1469 "result variable for the operation already declared"); 1470 } 1471 if (failed(emitVariableDeclaration(result.getOwner()->getLoc(), 1472 result.getType(), 1473 getOrCreateName(result)))) 1474 return failure(); 1475 if (trailingSemicolon) 1476 os << ";\n"; 1477 return success(); 1478 } 1479 1480 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { 1481 if (op.getExternSpecifier()) 1482 os << "extern "; 1483 else if (op.getStaticSpecifier()) 1484 os << "static "; 1485 if (op.getConstSpecifier()) 1486 os << "const "; 1487 1488 if (failed(emitVariableDeclaration(op->getLoc(), op.getType(), 1489 op.getSymName()))) { 1490 return failure(); 1491 } 1492 1493 std::optional<Attribute> initialValue = op.getInitialValue(); 1494 if (initialValue) { 1495 os << " = "; 1496 if (failed(emitAttribute(op->getLoc(), *initialValue))) 1497 return failure(); 1498 } 1499 1500 os << ";"; 1501 return success(); 1502 } 1503 1504 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 1505 // If op is being emitted as part of an expression, bail out. 1506 if (getEmittedExpression()) 1507 return success(); 1508 1509 switch (op.getNumResults()) { 1510 case 0: 1511 break; 1512 case 1: { 1513 OpResult result = op.getResult(0); 1514 if (shouldDeclareVariablesAtTop()) { 1515 if (failed(emitVariableAssignment(result))) 1516 return failure(); 1517 } else { 1518 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 1519 return failure(); 1520 os << " = "; 1521 } 1522 break; 1523 } 1524 default: 1525 if (!shouldDeclareVariablesAtTop()) { 1526 for (OpResult result : op.getResults()) { 1527 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 1528 return failure(); 1529 } 1530 } 1531 os << "std::tie("; 1532 interleaveComma(op.getResults(), os, 1533 [&](Value result) { os << getOrCreateName(result); }); 1534 os << ") = "; 1535 } 1536 return success(); 1537 } 1538 1539 LogicalResult CppEmitter::emitLabel(Block &block) { 1540 if (!hasBlockLabel(block)) 1541 return block.getParentOp()->emitError("label for block not found"); 1542 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 1543 // label instead of using `getOStream`. 1544 os.getOStream() << getOrCreateName(block) << ":\n"; 1545 return success(); 1546 } 1547 1548 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 1549 LogicalResult status = 1550 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 1551 // Builtin ops. 1552 .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); }) 1553 // CF ops. 1554 .Case<cf::BranchOp, cf::CondBranchOp>( 1555 [&](auto op) { return printOperation(*this, op); }) 1556 // EmitC ops. 1557 .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, 1558 emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp, 1559 emitc::BitwiseNotOp, emitc::BitwiseOrOp, 1560 emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp, 1561 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, 1562 emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, 1563 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, 1564 emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp, 1565 emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, 1566 emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, 1567 emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, 1568 emitc::VariableOp, emitc::VerbatimOp>( 1569 [&](auto op) { return printOperation(*this, op); }) 1570 // Func ops. 1571 .Case<func::CallOp, func::FuncOp, func::ReturnOp>( 1572 [&](auto op) { return printOperation(*this, op); }) 1573 .Case<emitc::GetGlobalOp>([&](auto op) { 1574 cacheDeferredOpResult(op.getResult(), op.getName()); 1575 return success(); 1576 }) 1577 .Case<emitc::LiteralOp>([&](auto op) { 1578 cacheDeferredOpResult(op.getResult(), op.getValue()); 1579 return success(); 1580 }) 1581 .Case<emitc::MemberOp>([&](auto op) { 1582 cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); 1583 return success(); 1584 }) 1585 .Case<emitc::MemberOfPtrOp>([&](auto op) { 1586 cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); 1587 return success(); 1588 }) 1589 .Case<emitc::SubscriptOp>([&](auto op) { 1590 cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); 1591 return success(); 1592 }) 1593 .Default([&](Operation *) { 1594 return op.emitOpError("unable to find printer for op"); 1595 }); 1596 1597 if (failed(status)) 1598 return failure(); 1599 1600 if (hasDeferredEmission(&op)) 1601 return success(); 1602 1603 if (getEmittedExpression() || 1604 (isa<emitc::ExpressionOp>(op) && 1605 shouldBeInlined(cast<emitc::ExpressionOp>(op)))) 1606 return success(); 1607 1608 // Never emit a semicolon for some operations, especially if endening with 1609 // `}`. 1610 trailingSemicolon &= 1611 !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp, emitc::IfOp, 1612 emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(op); 1613 1614 os << (trailingSemicolon ? ";\n" : "\n"); 1615 1616 return success(); 1617 } 1618 1619 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, 1620 StringRef name) { 1621 if (auto arrType = dyn_cast<emitc::ArrayType>(type)) { 1622 if (failed(emitType(loc, arrType.getElementType()))) 1623 return failure(); 1624 os << " " << name; 1625 for (auto dim : arrType.getShape()) { 1626 os << "[" << dim << "]"; 1627 } 1628 return success(); 1629 } 1630 if (failed(emitType(loc, type))) 1631 return failure(); 1632 os << " " << name; 1633 return success(); 1634 } 1635 1636 LogicalResult CppEmitter::emitType(Location loc, Type type) { 1637 if (auto iType = dyn_cast<IntegerType>(type)) { 1638 switch (iType.getWidth()) { 1639 case 1: 1640 return (os << "bool"), success(); 1641 case 8: 1642 case 16: 1643 case 32: 1644 case 64: 1645 if (shouldMapToUnsigned(iType.getSignedness())) 1646 return (os << "uint" << iType.getWidth() << "_t"), success(); 1647 else 1648 return (os << "int" << iType.getWidth() << "_t"), success(); 1649 default: 1650 return emitError(loc, "cannot emit integer type ") << type; 1651 } 1652 } 1653 if (auto fType = dyn_cast<FloatType>(type)) { 1654 switch (fType.getWidth()) { 1655 case 16: { 1656 if (llvm::isa<Float16Type>(type)) 1657 return (os << "_Float16"), success(); 1658 else if (llvm::isa<BFloat16Type>(type)) 1659 return (os << "__bf16"), success(); 1660 else 1661 return emitError(loc, "cannot emit float type ") << type; 1662 } 1663 case 32: 1664 return (os << "float"), success(); 1665 case 64: 1666 return (os << "double"), success(); 1667 default: 1668 return emitError(loc, "cannot emit float type ") << type; 1669 } 1670 } 1671 if (auto iType = dyn_cast<IndexType>(type)) 1672 return (os << "size_t"), success(); 1673 if (auto sType = dyn_cast<emitc::SizeTType>(type)) 1674 return (os << "size_t"), success(); 1675 if (auto sType = dyn_cast<emitc::SignedSizeTType>(type)) 1676 return (os << "ssize_t"), success(); 1677 if (auto pType = dyn_cast<emitc::PtrDiffTType>(type)) 1678 return (os << "ptrdiff_t"), success(); 1679 if (auto tType = dyn_cast<TensorType>(type)) { 1680 if (!tType.hasRank()) 1681 return emitError(loc, "cannot emit unranked tensor type"); 1682 if (!tType.hasStaticShape()) 1683 return emitError(loc, "cannot emit tensor type with non static shape"); 1684 os << "Tensor<"; 1685 if (isa<ArrayType>(tType.getElementType())) 1686 return emitError(loc, "cannot emit tensor of array type ") << type; 1687 if (failed(emitType(loc, tType.getElementType()))) 1688 return failure(); 1689 auto shape = tType.getShape(); 1690 for (auto dimSize : shape) { 1691 os << ", "; 1692 os << dimSize; 1693 } 1694 os << ">"; 1695 return success(); 1696 } 1697 if (auto tType = dyn_cast<TupleType>(type)) 1698 return emitTupleType(loc, tType.getTypes()); 1699 if (auto oType = dyn_cast<emitc::OpaqueType>(type)) { 1700 os << oType.getValue(); 1701 return success(); 1702 } 1703 if (auto aType = dyn_cast<emitc::ArrayType>(type)) { 1704 if (failed(emitType(loc, aType.getElementType()))) 1705 return failure(); 1706 for (auto dim : aType.getShape()) 1707 os << "[" << dim << "]"; 1708 return success(); 1709 } 1710 if (auto lType = dyn_cast<emitc::LValueType>(type)) 1711 return emitType(loc, lType.getValueType()); 1712 if (auto pType = dyn_cast<emitc::PointerType>(type)) { 1713 if (isa<ArrayType>(pType.getPointee())) 1714 return emitError(loc, "cannot emit pointer to array type ") << type; 1715 if (failed(emitType(loc, pType.getPointee()))) 1716 return failure(); 1717 os << "*"; 1718 return success(); 1719 } 1720 return emitError(loc, "cannot emit type ") << type; 1721 } 1722 1723 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 1724 switch (types.size()) { 1725 case 0: 1726 os << "void"; 1727 return success(); 1728 case 1: 1729 return emitType(loc, types.front()); 1730 default: 1731 return emitTupleType(loc, types); 1732 } 1733 } 1734 1735 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 1736 if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) { 1737 return emitError(loc, "cannot emit tuple of array type"); 1738 } 1739 os << "std::tuple<"; 1740 if (failed(interleaveCommaWithError( 1741 types, os, [&](Type type) { return emitType(loc, type); }))) 1742 return failure(); 1743 os << ">"; 1744 return success(); 1745 } 1746 1747 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 1748 bool declareVariablesAtTop) { 1749 CppEmitter emitter(os, declareVariablesAtTop); 1750 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 1751 } 1752