1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines the control flow operations in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 16 #include "mlir/Interfaces/CallInterfaces.h" 17 18 #include "SPIRVOpUtils.h" 19 #include "SPIRVParsingUtils.h" 20 21 using namespace mlir::spirv::AttrNames; 22 23 namespace mlir::spirv { 24 25 /// Parses Function, Selection and Loop control attributes. If no control is 26 /// specified, "None" is used as a default. 27 template <typename EnumAttrClass, typename EnumClass> 28 static ParseResult 29 parseControlAttribute(OpAsmParser &parser, OperationState &state, 30 StringRef attrName = spirv::attributeName<EnumClass>()) { 31 if (succeeded(parser.parseOptionalKeyword(kControl))) { 32 EnumClass control; 33 if (parser.parseLParen() || 34 spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) || 35 parser.parseRParen()) 36 return failure(); 37 return success(); 38 } 39 // Set control to "None" otherwise. 40 Builder builder = parser.getBuilder(); 41 state.addAttribute(attrName, 42 builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0))); 43 return success(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // spirv.BranchOp 48 //===----------------------------------------------------------------------===// 49 50 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 51 assert(index == 0 && "invalid successor index"); 52 return SuccessorOperands(0, getTargetOperandsMutable()); 53 } 54 55 //===----------------------------------------------------------------------===// 56 // spirv.BranchConditionalOp 57 //===----------------------------------------------------------------------===// 58 59 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) { 60 assert(index < 2 && "invalid successor index"); 61 return SuccessorOperands(index == kTrueIndex 62 ? getTrueTargetOperandsMutable() 63 : getFalseTargetOperandsMutable()); 64 } 65 66 ParseResult BranchConditionalOp::parse(OpAsmParser &parser, 67 OperationState &result) { 68 auto &builder = parser.getBuilder(); 69 OpAsmParser::UnresolvedOperand condInfo; 70 Block *dest; 71 72 // Parse the condition. 73 Type boolTy = builder.getI1Type(); 74 if (parser.parseOperand(condInfo) || 75 parser.resolveOperand(condInfo, boolTy, result.operands)) 76 return failure(); 77 78 // Parse the optional branch weights. 79 if (succeeded(parser.parseOptionalLSquare())) { 80 IntegerAttr trueWeight, falseWeight; 81 NamedAttrList weights; 82 83 auto i32Type = builder.getIntegerType(32); 84 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || 85 parser.parseComma() || 86 parser.parseAttribute(falseWeight, i32Type, "weight", weights) || 87 parser.parseRSquare()) 88 return failure(); 89 90 StringAttr branchWeightsAttrName = 91 BranchConditionalOp::getBranchWeightsAttrName(result.name); 92 result.addAttribute(branchWeightsAttrName, 93 builder.getArrayAttr({trueWeight, falseWeight})); 94 } 95 96 // Parse the true branch. 97 SmallVector<Value, 4> trueOperands; 98 if (parser.parseComma() || 99 parser.parseSuccessorAndUseList(dest, trueOperands)) 100 return failure(); 101 result.addSuccessors(dest); 102 result.addOperands(trueOperands); 103 104 // Parse the false branch. 105 SmallVector<Value, 4> falseOperands; 106 if (parser.parseComma() || 107 parser.parseSuccessorAndUseList(dest, falseOperands)) 108 return failure(); 109 result.addSuccessors(dest); 110 result.addOperands(falseOperands); 111 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), 112 builder.getDenseI32ArrayAttr( 113 {1, static_cast<int32_t>(trueOperands.size()), 114 static_cast<int32_t>(falseOperands.size())})); 115 116 return success(); 117 } 118 119 void BranchConditionalOp::print(OpAsmPrinter &printer) { 120 printer << ' ' << getCondition(); 121 122 if (auto weights = getBranchWeights()) { 123 printer << " ["; 124 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { 125 printer << llvm::cast<IntegerAttr>(a).getInt(); 126 }); 127 printer << "]"; 128 } 129 130 printer << ", "; 131 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); 132 printer << ", "; 133 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); 134 } 135 136 LogicalResult BranchConditionalOp::verify() { 137 if (auto weights = getBranchWeights()) { 138 if (weights->getValue().size() != 2) { 139 return emitOpError("must have exactly two branch weights"); 140 } 141 if (llvm::all_of(*weights, [](Attribute attr) { 142 return llvm::cast<IntegerAttr>(attr).getValue().isZero(); 143 })) 144 return emitOpError("branch weights cannot both be zero"); 145 } 146 147 return success(); 148 } 149 150 //===----------------------------------------------------------------------===// 151 // spirv.FunctionCall 152 //===----------------------------------------------------------------------===// 153 154 LogicalResult FunctionCallOp::verify() { 155 auto fnName = getCalleeAttr(); 156 157 auto funcOp = dyn_cast_or_null<spirv::FuncOp>( 158 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); 159 if (!funcOp) { 160 return emitOpError("callee function '") 161 << fnName.getValue() << "' not found in nearest symbol table"; 162 } 163 164 auto functionType = funcOp.getFunctionType(); 165 166 if (getNumResults() > 1) { 167 return emitOpError( 168 "expected callee function to have 0 or 1 result, but provided ") 169 << getNumResults(); 170 } 171 172 if (functionType.getNumInputs() != getNumOperands()) { 173 return emitOpError("has incorrect number of operands for callee: expected ") 174 << functionType.getNumInputs() << ", but provided " 175 << getNumOperands(); 176 } 177 178 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 179 if (getOperand(i).getType() != functionType.getInput(i)) { 180 return emitOpError("operand type mismatch: expected operand type ") 181 << functionType.getInput(i) << ", but provided " 182 << getOperand(i).getType() << " for operand number " << i; 183 } 184 } 185 186 if (functionType.getNumResults() != getNumResults()) { 187 return emitOpError( 188 "has incorrect number of results has for callee: expected ") 189 << functionType.getNumResults() << ", but provided " 190 << getNumResults(); 191 } 192 193 if (getNumResults() && 194 (getResult(0).getType() != functionType.getResult(0))) { 195 return emitOpError("result type mismatch: expected ") 196 << functionType.getResult(0) << ", but provided " 197 << getResult(0).getType(); 198 } 199 200 return success(); 201 } 202 203 CallInterfaceCallable FunctionCallOp::getCallableForCallee() { 204 return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName()); 205 } 206 207 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { 208 (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee)); 209 } 210 211 Operation::operand_range FunctionCallOp::getArgOperands() { 212 return getArguments(); 213 } 214 215 MutableOperandRange FunctionCallOp::getArgOperandsMutable() { 216 return getArgumentsMutable(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // spirv.mlir.loop 221 //===----------------------------------------------------------------------===// 222 223 void LoopOp::build(OpBuilder &builder, OperationState &state) { 224 state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>( 225 spirv::LoopControl::None)); 226 state.addRegion(); 227 } 228 229 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { 230 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser, 231 result)) 232 return failure(); 233 return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); 234 } 235 236 void LoopOp::print(OpAsmPrinter &printer) { 237 auto control = getLoopControl(); 238 if (control != spirv::LoopControl::None) 239 printer << " control(" << spirv::stringifyLoopControl(control) << ")"; 240 printer << ' '; 241 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 242 /*printBlockTerminators=*/true); 243 } 244 245 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the 246 /// given `dstBlock`. 247 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { 248 // Check that there is only one op in the `srcBlock`. 249 if (!llvm::hasSingleElement(srcBlock)) 250 return false; 251 252 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back()); 253 return branchOp && branchOp.getSuccessor() == &dstBlock; 254 } 255 256 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op. 257 static bool isMergeBlock(Block &block) { 258 return !block.empty() && std::next(block.begin()) == block.end() && 259 isa<spirv::MergeOp>(block.front()); 260 } 261 262 LogicalResult LoopOp::verifyRegions() { 263 auto *op = getOperation(); 264 265 // We need to verify that the blocks follow the following layout: 266 // 267 // +-------------+ 268 // | entry block | 269 // +-------------+ 270 // | 271 // v 272 // +-------------+ 273 // | loop header | <-----+ 274 // +-------------+ | 275 // | 276 // ... | 277 // \ | / | 278 // v | 279 // +---------------+ | 280 // | loop continue | -----+ 281 // +---------------+ 282 // 283 // ... 284 // \ | / 285 // v 286 // +-------------+ 287 // | merge block | 288 // +-------------+ 289 290 auto ®ion = op->getRegion(0); 291 // Allow empty region as a degenerated case, which can come from 292 // optimizations. 293 if (region.empty()) 294 return success(); 295 296 // The last block is the merge block. 297 Block &merge = region.back(); 298 if (!isMergeBlock(merge)) 299 return emitOpError("last block must be the merge block with only one " 300 "'spirv.mlir.merge' op"); 301 302 if (std::next(region.begin()) == region.end()) 303 return emitOpError( 304 "must have an entry block branching to the loop header block"); 305 // The first block is the entry block. 306 Block &entry = region.front(); 307 308 if (std::next(region.begin(), 2) == region.end()) 309 return emitOpError( 310 "must have a loop header block branched from the entry block"); 311 // The second block is the loop header block. 312 Block &header = *std::next(region.begin(), 1); 313 314 if (!hasOneBranchOpTo(entry, header)) 315 return emitOpError( 316 "entry block must only have one 'spirv.Branch' op to the second block"); 317 318 if (std::next(region.begin(), 3) == region.end()) 319 return emitOpError( 320 "requires a loop continue block branching to the loop header block"); 321 // The second to last block is the loop continue block. 322 Block &cont = *std::prev(region.end(), 2); 323 324 // Make sure that we have a branch from the loop continue block to the loop 325 // header block. 326 if (llvm::none_of( 327 llvm::seq<unsigned>(0, cont.getNumSuccessors()), 328 [&](unsigned index) { return cont.getSuccessor(index) == &header; })) 329 return emitOpError("second to last block must be the loop continue " 330 "block that branches to the loop header block"); 331 332 // Make sure that no other blocks (except the entry and loop continue block) 333 // branches to the loop header block. 334 for (auto &block : llvm::make_range(std::next(region.begin(), 2), 335 std::prev(region.end(), 2))) { 336 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) { 337 if (block.getSuccessor(i) == &header) { 338 return emitOpError("can only have the entry and loop continue " 339 "block branching to the loop header block"); 340 } 341 } 342 } 343 344 return success(); 345 } 346 347 Block *LoopOp::getEntryBlock() { 348 assert(!getBody().empty() && "op region should not be empty!"); 349 return &getBody().front(); 350 } 351 352 Block *LoopOp::getHeaderBlock() { 353 assert(!getBody().empty() && "op region should not be empty!"); 354 // The second block is the loop header block. 355 return &*std::next(getBody().begin()); 356 } 357 358 Block *LoopOp::getContinueBlock() { 359 assert(!getBody().empty() && "op region should not be empty!"); 360 // The second to last block is the loop continue block. 361 return &*std::prev(getBody().end(), 2); 362 } 363 364 Block *LoopOp::getMergeBlock() { 365 assert(!getBody().empty() && "op region should not be empty!"); 366 // The last block is the loop merge block. 367 return &getBody().back(); 368 } 369 370 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { 371 assert(getBody().empty() && "entry and merge block already exist"); 372 OpBuilder::InsertionGuard g(builder); 373 builder.createBlock(&getBody()); 374 builder.createBlock(&getBody()); 375 376 // Add a spirv.mlir.merge op into the merge block. 377 builder.create<spirv::MergeOp>(getLoc()); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // spirv.mlir.merge 382 //===----------------------------------------------------------------------===// 383 384 LogicalResult MergeOp::verify() { 385 auto *parentOp = (*this)->getParentOp(); 386 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp)) 387 return emitOpError( 388 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'"); 389 390 // TODO: This check should be done in `verifyRegions` of parent op. 391 Block &parentLastBlock = (*this)->getParentRegion()->back(); 392 if (getOperation() != parentLastBlock.getTerminator()) 393 return emitOpError("can only be used in the last block of " 394 "'spirv.mlir.selection' or 'spirv.mlir.loop'"); 395 return success(); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // spirv.Return 400 //===----------------------------------------------------------------------===// 401 402 LogicalResult ReturnOp::verify() { 403 // Verification is performed in spirv.func op. 404 return success(); 405 } 406 407 //===----------------------------------------------------------------------===// 408 // spirv.ReturnValue 409 //===----------------------------------------------------------------------===// 410 411 LogicalResult ReturnValueOp::verify() { 412 // Verification is performed in spirv.func op. 413 return success(); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // spirv.Select 418 //===----------------------------------------------------------------------===// 419 420 LogicalResult SelectOp::verify() { 421 if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) { 422 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType()); 423 if (!resultVectorTy) { 424 return emitOpError("result expected to be of vector type when " 425 "condition is of vector type"); 426 } 427 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { 428 return emitOpError("result should have the same number of elements as " 429 "the condition when condition is of vector type"); 430 } 431 } 432 return success(); 433 } 434 435 // Custom availability implementation is needed for spirv.Select given the 436 // syntax changes starting v1.4. 437 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() { 438 return {}; 439 } 440 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() { 441 return {}; 442 } 443 std::optional<spirv::Version> SelectOp::getMinVersion() { 444 // Per the spec, "Before version 1.4, results are only computed per 445 // component." 446 if (isa<spirv::ScalarType>(getCondition().getType()) && 447 isa<spirv::CompositeType>(getType())) 448 return Version::V_1_4; 449 450 return Version::V_1_0; 451 } 452 std::optional<spirv::Version> SelectOp::getMaxVersion() { 453 return Version::V_1_6; 454 } 455 456 //===----------------------------------------------------------------------===// 457 // spirv.mlir.selection 458 //===----------------------------------------------------------------------===// 459 460 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) { 461 if (parseControlAttribute<spirv::SelectionControlAttr, 462 spirv::SelectionControl>(parser, result)) 463 return failure(); 464 return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); 465 } 466 467 void SelectionOp::print(OpAsmPrinter &printer) { 468 auto control = getSelectionControl(); 469 if (control != spirv::SelectionControl::None) 470 printer << " control(" << spirv::stringifySelectionControl(control) << ")"; 471 printer << ' '; 472 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 473 /*printBlockTerminators=*/true); 474 } 475 476 LogicalResult SelectionOp::verifyRegions() { 477 auto *op = getOperation(); 478 479 // We need to verify that the blocks follow the following layout: 480 // 481 // +--------------+ 482 // | header block | 483 // +--------------+ 484 // / | \ 485 // ... 486 // 487 // 488 // +---------+ +---------+ +---------+ 489 // | case #0 | | case #1 | | case #2 | ... 490 // +---------+ +---------+ +---------+ 491 // 492 // 493 // ... 494 // \ | / 495 // v 496 // +-------------+ 497 // | merge block | 498 // +-------------+ 499 500 auto ®ion = op->getRegion(0); 501 // Allow empty region as a degenerated case, which can come from 502 // optimizations. 503 if (region.empty()) 504 return success(); 505 506 // The last block is the merge block. 507 if (!isMergeBlock(region.back())) 508 return emitOpError("last block must be the merge block with only one " 509 "'spirv.mlir.merge' op"); 510 511 if (std::next(region.begin()) == region.end()) 512 return emitOpError("must have a selection header block"); 513 514 return success(); 515 } 516 517 Block *SelectionOp::getHeaderBlock() { 518 assert(!getBody().empty() && "op region should not be empty!"); 519 // The first block is the loop header block. 520 return &getBody().front(); 521 } 522 523 Block *SelectionOp::getMergeBlock() { 524 assert(!getBody().empty() && "op region should not be empty!"); 525 // The last block is the loop merge block. 526 return &getBody().back(); 527 } 528 529 void SelectionOp::addMergeBlock(OpBuilder &builder) { 530 assert(getBody().empty() && "entry and merge block already exist"); 531 OpBuilder::InsertionGuard guard(builder); 532 builder.createBlock(&getBody()); 533 534 // Add a spirv.mlir.merge op into the merge block. 535 builder.create<spirv::MergeOp>(getLoc()); 536 } 537 538 SelectionOp 539 SelectionOp::createIfThen(Location loc, Value condition, 540 function_ref<void(OpBuilder &builder)> thenBody, 541 OpBuilder &builder) { 542 auto selectionOp = 543 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); 544 545 selectionOp.addMergeBlock(builder); 546 Block *mergeBlock = selectionOp.getMergeBlock(); 547 Block *thenBlock = nullptr; 548 549 // Build the "then" block. 550 { 551 OpBuilder::InsertionGuard guard(builder); 552 thenBlock = builder.createBlock(mergeBlock); 553 thenBody(builder); 554 builder.create<spirv::BranchOp>(loc, mergeBlock); 555 } 556 557 // Build the header block. 558 { 559 OpBuilder::InsertionGuard guard(builder); 560 builder.createBlock(thenBlock); 561 builder.create<spirv::BranchConditionalOp>( 562 loc, condition, thenBlock, 563 /*trueArguments=*/ArrayRef<Value>(), mergeBlock, 564 /*falseArguments=*/ArrayRef<Value>()); 565 } 566 567 return selectionOp; 568 } 569 570 //===----------------------------------------------------------------------===// 571 // spirv.Unreachable 572 //===----------------------------------------------------------------------===// 573 574 LogicalResult spirv::UnreachableOp::verify() { 575 auto *block = (*this)->getBlock(); 576 // Fast track: if this is in entry block, its invalid. Otherwise, if no 577 // predecessors, it's valid. 578 if (block->isEntryBlock()) 579 return emitOpError("cannot be used in reachable block"); 580 if (block->hasNoPredecessors()) 581 return success(); 582 583 // TODO: further verification needs to analyze reachability from 584 // the entry block. 585 586 return success(); 587 } 588 589 } // namespace mlir::spirv 590