181b4e7d2Svarconst //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===// 281b4e7d2Svarconst // 381b4e7d2Svarconst // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 481b4e7d2Svarconst // See https://llvm.org/LICENSE.txt for license information. 581b4e7d2Svarconst // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 681b4e7d2Svarconst // 781b4e7d2Svarconst //===----------------------------------------------------------------------===// 881b4e7d2Svarconst // 981b4e7d2Svarconst // Defines the control flow operations in the SPIR-V dialect. 1081b4e7d2Svarconst // 1181b4e7d2Svarconst //===----------------------------------------------------------------------===// 1281b4e7d2Svarconst 136d9eb31cSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 1481b4e7d2Svarconst #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 156d9eb31cSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 1681b4e7d2Svarconst #include "mlir/Interfaces/CallInterfaces.h" 1781b4e7d2Svarconst 1881b4e7d2Svarconst #include "SPIRVOpUtils.h" 1981b4e7d2Svarconst #include "SPIRVParsingUtils.h" 2081b4e7d2Svarconst 2181b4e7d2Svarconst using namespace mlir::spirv::AttrNames; 2281b4e7d2Svarconst 2381b4e7d2Svarconst namespace mlir::spirv { 2481b4e7d2Svarconst 2581b4e7d2Svarconst /// Parses Function, Selection and Loop control attributes. If no control is 2681b4e7d2Svarconst /// specified, "None" is used as a default. 2781b4e7d2Svarconst template <typename EnumAttrClass, typename EnumClass> 2881b4e7d2Svarconst static ParseResult 2981b4e7d2Svarconst parseControlAttribute(OpAsmParser &parser, OperationState &state, 3081b4e7d2Svarconst StringRef attrName = spirv::attributeName<EnumClass>()) { 3181b4e7d2Svarconst if (succeeded(parser.parseOptionalKeyword(kControl))) { 3281b4e7d2Svarconst EnumClass control; 3381b4e7d2Svarconst if (parser.parseLParen() || 3481b4e7d2Svarconst spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) || 3581b4e7d2Svarconst parser.parseRParen()) 3681b4e7d2Svarconst return failure(); 3781b4e7d2Svarconst return success(); 3881b4e7d2Svarconst } 3981b4e7d2Svarconst // Set control to "None" otherwise. 4081b4e7d2Svarconst Builder builder = parser.getBuilder(); 4181b4e7d2Svarconst state.addAttribute(attrName, 4281b4e7d2Svarconst builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0))); 4381b4e7d2Svarconst return success(); 4481b4e7d2Svarconst } 4581b4e7d2Svarconst 4681b4e7d2Svarconst //===----------------------------------------------------------------------===// 4781b4e7d2Svarconst // spirv.BranchOp 4881b4e7d2Svarconst //===----------------------------------------------------------------------===// 4981b4e7d2Svarconst 5081b4e7d2Svarconst SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { 5181b4e7d2Svarconst assert(index == 0 && "invalid successor index"); 5281b4e7d2Svarconst return SuccessorOperands(0, getTargetOperandsMutable()); 5381b4e7d2Svarconst } 5481b4e7d2Svarconst 5581b4e7d2Svarconst //===----------------------------------------------------------------------===// 5681b4e7d2Svarconst // spirv.BranchConditionalOp 5781b4e7d2Svarconst //===----------------------------------------------------------------------===// 5881b4e7d2Svarconst 5981b4e7d2Svarconst SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) { 6081b4e7d2Svarconst assert(index < 2 && "invalid successor index"); 6181b4e7d2Svarconst return SuccessorOperands(index == kTrueIndex 6281b4e7d2Svarconst ? getTrueTargetOperandsMutable() 6381b4e7d2Svarconst : getFalseTargetOperandsMutable()); 6481b4e7d2Svarconst } 6581b4e7d2Svarconst 6681b4e7d2Svarconst ParseResult BranchConditionalOp::parse(OpAsmParser &parser, 6781b4e7d2Svarconst OperationState &result) { 6881b4e7d2Svarconst auto &builder = parser.getBuilder(); 6981b4e7d2Svarconst OpAsmParser::UnresolvedOperand condInfo; 7081b4e7d2Svarconst Block *dest; 7181b4e7d2Svarconst 7281b4e7d2Svarconst // Parse the condition. 7381b4e7d2Svarconst Type boolTy = builder.getI1Type(); 7481b4e7d2Svarconst if (parser.parseOperand(condInfo) || 7581b4e7d2Svarconst parser.resolveOperand(condInfo, boolTy, result.operands)) 7681b4e7d2Svarconst return failure(); 7781b4e7d2Svarconst 7881b4e7d2Svarconst // Parse the optional branch weights. 7981b4e7d2Svarconst if (succeeded(parser.parseOptionalLSquare())) { 8081b4e7d2Svarconst IntegerAttr trueWeight, falseWeight; 8181b4e7d2Svarconst NamedAttrList weights; 8281b4e7d2Svarconst 8381b4e7d2Svarconst auto i32Type = builder.getIntegerType(32); 8481b4e7d2Svarconst if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || 8581b4e7d2Svarconst parser.parseComma() || 8681b4e7d2Svarconst parser.parseAttribute(falseWeight, i32Type, "weight", weights) || 8781b4e7d2Svarconst parser.parseRSquare()) 8881b4e7d2Svarconst return failure(); 8981b4e7d2Svarconst 901d5e3b2dStw-ilson StringAttr branchWeightsAttrName = 911d5e3b2dStw-ilson BranchConditionalOp::getBranchWeightsAttrName(result.name); 921d5e3b2dStw-ilson result.addAttribute(branchWeightsAttrName, 9381b4e7d2Svarconst builder.getArrayAttr({trueWeight, falseWeight})); 9481b4e7d2Svarconst } 9581b4e7d2Svarconst 9681b4e7d2Svarconst // Parse the true branch. 9781b4e7d2Svarconst SmallVector<Value, 4> trueOperands; 9881b4e7d2Svarconst if (parser.parseComma() || 9981b4e7d2Svarconst parser.parseSuccessorAndUseList(dest, trueOperands)) 10081b4e7d2Svarconst return failure(); 10181b4e7d2Svarconst result.addSuccessors(dest); 10281b4e7d2Svarconst result.addOperands(trueOperands); 10381b4e7d2Svarconst 10481b4e7d2Svarconst // Parse the false branch. 10581b4e7d2Svarconst SmallVector<Value, 4> falseOperands; 10681b4e7d2Svarconst if (parser.parseComma() || 10781b4e7d2Svarconst parser.parseSuccessorAndUseList(dest, falseOperands)) 10881b4e7d2Svarconst return failure(); 10981b4e7d2Svarconst result.addSuccessors(dest); 11081b4e7d2Svarconst result.addOperands(falseOperands); 11181b4e7d2Svarconst result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), 11281b4e7d2Svarconst builder.getDenseI32ArrayAttr( 11381b4e7d2Svarconst {1, static_cast<int32_t>(trueOperands.size()), 11481b4e7d2Svarconst static_cast<int32_t>(falseOperands.size())})); 11581b4e7d2Svarconst 11681b4e7d2Svarconst return success(); 11781b4e7d2Svarconst } 11881b4e7d2Svarconst 11981b4e7d2Svarconst void BranchConditionalOp::print(OpAsmPrinter &printer) { 12081b4e7d2Svarconst printer << ' ' << getCondition(); 12181b4e7d2Svarconst 12281b4e7d2Svarconst if (auto weights = getBranchWeights()) { 12381b4e7d2Svarconst printer << " ["; 12481b4e7d2Svarconst llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { 12581b4e7d2Svarconst printer << llvm::cast<IntegerAttr>(a).getInt(); 12681b4e7d2Svarconst }); 12781b4e7d2Svarconst printer << "]"; 12881b4e7d2Svarconst } 12981b4e7d2Svarconst 13081b4e7d2Svarconst printer << ", "; 13181b4e7d2Svarconst printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); 13281b4e7d2Svarconst printer << ", "; 13381b4e7d2Svarconst printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); 13481b4e7d2Svarconst } 13581b4e7d2Svarconst 13681b4e7d2Svarconst LogicalResult BranchConditionalOp::verify() { 13781b4e7d2Svarconst if (auto weights = getBranchWeights()) { 13881b4e7d2Svarconst if (weights->getValue().size() != 2) { 13981b4e7d2Svarconst return emitOpError("must have exactly two branch weights"); 14081b4e7d2Svarconst } 14181b4e7d2Svarconst if (llvm::all_of(*weights, [](Attribute attr) { 14281b4e7d2Svarconst return llvm::cast<IntegerAttr>(attr).getValue().isZero(); 14381b4e7d2Svarconst })) 14481b4e7d2Svarconst return emitOpError("branch weights cannot both be zero"); 14581b4e7d2Svarconst } 14681b4e7d2Svarconst 14781b4e7d2Svarconst return success(); 14881b4e7d2Svarconst } 14981b4e7d2Svarconst 15081b4e7d2Svarconst //===----------------------------------------------------------------------===// 15181b4e7d2Svarconst // spirv.FunctionCall 15281b4e7d2Svarconst //===----------------------------------------------------------------------===// 15381b4e7d2Svarconst 15481b4e7d2Svarconst LogicalResult FunctionCallOp::verify() { 15581b4e7d2Svarconst auto fnName = getCalleeAttr(); 15681b4e7d2Svarconst 15781b4e7d2Svarconst auto funcOp = dyn_cast_or_null<spirv::FuncOp>( 15881b4e7d2Svarconst SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); 15981b4e7d2Svarconst if (!funcOp) { 16081b4e7d2Svarconst return emitOpError("callee function '") 16181b4e7d2Svarconst << fnName.getValue() << "' not found in nearest symbol table"; 16281b4e7d2Svarconst } 16381b4e7d2Svarconst 16481b4e7d2Svarconst auto functionType = funcOp.getFunctionType(); 16581b4e7d2Svarconst 16681b4e7d2Svarconst if (getNumResults() > 1) { 16781b4e7d2Svarconst return emitOpError( 16881b4e7d2Svarconst "expected callee function to have 0 or 1 result, but provided ") 16981b4e7d2Svarconst << getNumResults(); 17081b4e7d2Svarconst } 17181b4e7d2Svarconst 17281b4e7d2Svarconst if (functionType.getNumInputs() != getNumOperands()) { 17381b4e7d2Svarconst return emitOpError("has incorrect number of operands for callee: expected ") 17481b4e7d2Svarconst << functionType.getNumInputs() << ", but provided " 17581b4e7d2Svarconst << getNumOperands(); 17681b4e7d2Svarconst } 17781b4e7d2Svarconst 17881b4e7d2Svarconst for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 17981b4e7d2Svarconst if (getOperand(i).getType() != functionType.getInput(i)) { 18081b4e7d2Svarconst return emitOpError("operand type mismatch: expected operand type ") 18181b4e7d2Svarconst << functionType.getInput(i) << ", but provided " 18281b4e7d2Svarconst << getOperand(i).getType() << " for operand number " << i; 18381b4e7d2Svarconst } 18481b4e7d2Svarconst } 18581b4e7d2Svarconst 18681b4e7d2Svarconst if (functionType.getNumResults() != getNumResults()) { 18781b4e7d2Svarconst return emitOpError( 18881b4e7d2Svarconst "has incorrect number of results has for callee: expected ") 18981b4e7d2Svarconst << functionType.getNumResults() << ", but provided " 19081b4e7d2Svarconst << getNumResults(); 19181b4e7d2Svarconst } 19281b4e7d2Svarconst 19381b4e7d2Svarconst if (getNumResults() && 19481b4e7d2Svarconst (getResult(0).getType() != functionType.getResult(0))) { 19581b4e7d2Svarconst return emitOpError("result type mismatch: expected ") 19681b4e7d2Svarconst << functionType.getResult(0) << ", but provided " 19781b4e7d2Svarconst << getResult(0).getType(); 19881b4e7d2Svarconst } 19981b4e7d2Svarconst 20081b4e7d2Svarconst return success(); 20181b4e7d2Svarconst } 20281b4e7d2Svarconst 20381b4e7d2Svarconst CallInterfaceCallable FunctionCallOp::getCallableForCallee() { 2041d5e3b2dStw-ilson return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName()); 20581b4e7d2Svarconst } 20681b4e7d2Svarconst 20781b4e7d2Svarconst void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { 208*129f1001SKazu Hirata (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee)); 20981b4e7d2Svarconst } 21081b4e7d2Svarconst 21181b4e7d2Svarconst Operation::operand_range FunctionCallOp::getArgOperands() { 21281b4e7d2Svarconst return getArguments(); 21381b4e7d2Svarconst } 21481b4e7d2Svarconst 215d790a217SMartin Erhart MutableOperandRange FunctionCallOp::getArgOperandsMutable() { 216d790a217SMartin Erhart return getArgumentsMutable(); 217d790a217SMartin Erhart } 218d790a217SMartin Erhart 21981b4e7d2Svarconst //===----------------------------------------------------------------------===// 22081b4e7d2Svarconst // spirv.mlir.loop 22181b4e7d2Svarconst //===----------------------------------------------------------------------===// 22281b4e7d2Svarconst 22381b4e7d2Svarconst void LoopOp::build(OpBuilder &builder, OperationState &state) { 22481b4e7d2Svarconst state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>( 22581b4e7d2Svarconst spirv::LoopControl::None)); 22681b4e7d2Svarconst state.addRegion(); 22781b4e7d2Svarconst } 22881b4e7d2Svarconst 22981b4e7d2Svarconst ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { 23081b4e7d2Svarconst if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser, 23181b4e7d2Svarconst result)) 23281b4e7d2Svarconst return failure(); 23381b4e7d2Svarconst return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); 23481b4e7d2Svarconst } 23581b4e7d2Svarconst 23681b4e7d2Svarconst void LoopOp::print(OpAsmPrinter &printer) { 23781b4e7d2Svarconst auto control = getLoopControl(); 23881b4e7d2Svarconst if (control != spirv::LoopControl::None) 23981b4e7d2Svarconst printer << " control(" << spirv::stringifyLoopControl(control) << ")"; 24081b4e7d2Svarconst printer << ' '; 24181b4e7d2Svarconst printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 24281b4e7d2Svarconst /*printBlockTerminators=*/true); 24381b4e7d2Svarconst } 24481b4e7d2Svarconst 24581b4e7d2Svarconst /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the 24681b4e7d2Svarconst /// given `dstBlock`. 24781b4e7d2Svarconst static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { 24881b4e7d2Svarconst // Check that there is only one op in the `srcBlock`. 24981b4e7d2Svarconst if (!llvm::hasSingleElement(srcBlock)) 25081b4e7d2Svarconst return false; 25181b4e7d2Svarconst 25281b4e7d2Svarconst auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back()); 25381b4e7d2Svarconst return branchOp && branchOp.getSuccessor() == &dstBlock; 25481b4e7d2Svarconst } 25581b4e7d2Svarconst 25681b4e7d2Svarconst /// Returns true if the given `block` only contains one `spirv.mlir.merge` op. 25781b4e7d2Svarconst static bool isMergeBlock(Block &block) { 25881b4e7d2Svarconst return !block.empty() && std::next(block.begin()) == block.end() && 25981b4e7d2Svarconst isa<spirv::MergeOp>(block.front()); 26081b4e7d2Svarconst } 26181b4e7d2Svarconst 26281b4e7d2Svarconst LogicalResult LoopOp::verifyRegions() { 26381b4e7d2Svarconst auto *op = getOperation(); 26481b4e7d2Svarconst 26581b4e7d2Svarconst // We need to verify that the blocks follow the following layout: 26681b4e7d2Svarconst // 26781b4e7d2Svarconst // +-------------+ 26881b4e7d2Svarconst // | entry block | 26981b4e7d2Svarconst // +-------------+ 27081b4e7d2Svarconst // | 27181b4e7d2Svarconst // v 27281b4e7d2Svarconst // +-------------+ 27381b4e7d2Svarconst // | loop header | <-----+ 27481b4e7d2Svarconst // +-------------+ | 27581b4e7d2Svarconst // | 27681b4e7d2Svarconst // ... | 27781b4e7d2Svarconst // \ | / | 27881b4e7d2Svarconst // v | 27981b4e7d2Svarconst // +---------------+ | 28081b4e7d2Svarconst // | loop continue | -----+ 28181b4e7d2Svarconst // +---------------+ 28281b4e7d2Svarconst // 28381b4e7d2Svarconst // ... 28481b4e7d2Svarconst // \ | / 28581b4e7d2Svarconst // v 28681b4e7d2Svarconst // +-------------+ 28781b4e7d2Svarconst // | merge block | 28881b4e7d2Svarconst // +-------------+ 28981b4e7d2Svarconst 29081b4e7d2Svarconst auto ®ion = op->getRegion(0); 29181b4e7d2Svarconst // Allow empty region as a degenerated case, which can come from 29281b4e7d2Svarconst // optimizations. 29381b4e7d2Svarconst if (region.empty()) 29481b4e7d2Svarconst return success(); 29581b4e7d2Svarconst 29681b4e7d2Svarconst // The last block is the merge block. 29781b4e7d2Svarconst Block &merge = region.back(); 29881b4e7d2Svarconst if (!isMergeBlock(merge)) 29981b4e7d2Svarconst return emitOpError("last block must be the merge block with only one " 30081b4e7d2Svarconst "'spirv.mlir.merge' op"); 30181b4e7d2Svarconst 30281b4e7d2Svarconst if (std::next(region.begin()) == region.end()) 30381b4e7d2Svarconst return emitOpError( 30481b4e7d2Svarconst "must have an entry block branching to the loop header block"); 30581b4e7d2Svarconst // The first block is the entry block. 30681b4e7d2Svarconst Block &entry = region.front(); 30781b4e7d2Svarconst 30881b4e7d2Svarconst if (std::next(region.begin(), 2) == region.end()) 30981b4e7d2Svarconst return emitOpError( 31081b4e7d2Svarconst "must have a loop header block branched from the entry block"); 31181b4e7d2Svarconst // The second block is the loop header block. 31281b4e7d2Svarconst Block &header = *std::next(region.begin(), 1); 31381b4e7d2Svarconst 31481b4e7d2Svarconst if (!hasOneBranchOpTo(entry, header)) 31581b4e7d2Svarconst return emitOpError( 31681b4e7d2Svarconst "entry block must only have one 'spirv.Branch' op to the second block"); 31781b4e7d2Svarconst 31881b4e7d2Svarconst if (std::next(region.begin(), 3) == region.end()) 31981b4e7d2Svarconst return emitOpError( 32081b4e7d2Svarconst "requires a loop continue block branching to the loop header block"); 32181b4e7d2Svarconst // The second to last block is the loop continue block. 32281b4e7d2Svarconst Block &cont = *std::prev(region.end(), 2); 32381b4e7d2Svarconst 32481b4e7d2Svarconst // Make sure that we have a branch from the loop continue block to the loop 32581b4e7d2Svarconst // header block. 32681b4e7d2Svarconst if (llvm::none_of( 32781b4e7d2Svarconst llvm::seq<unsigned>(0, cont.getNumSuccessors()), 32881b4e7d2Svarconst [&](unsigned index) { return cont.getSuccessor(index) == &header; })) 32981b4e7d2Svarconst return emitOpError("second to last block must be the loop continue " 33081b4e7d2Svarconst "block that branches to the loop header block"); 33181b4e7d2Svarconst 33281b4e7d2Svarconst // Make sure that no other blocks (except the entry and loop continue block) 33381b4e7d2Svarconst // branches to the loop header block. 33481b4e7d2Svarconst for (auto &block : llvm::make_range(std::next(region.begin(), 2), 33581b4e7d2Svarconst std::prev(region.end(), 2))) { 33681b4e7d2Svarconst for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) { 33781b4e7d2Svarconst if (block.getSuccessor(i) == &header) { 33881b4e7d2Svarconst return emitOpError("can only have the entry and loop continue " 33981b4e7d2Svarconst "block branching to the loop header block"); 34081b4e7d2Svarconst } 34181b4e7d2Svarconst } 34281b4e7d2Svarconst } 34381b4e7d2Svarconst 34481b4e7d2Svarconst return success(); 34581b4e7d2Svarconst } 34681b4e7d2Svarconst 34781b4e7d2Svarconst Block *LoopOp::getEntryBlock() { 34881b4e7d2Svarconst assert(!getBody().empty() && "op region should not be empty!"); 34981b4e7d2Svarconst return &getBody().front(); 35081b4e7d2Svarconst } 35181b4e7d2Svarconst 35281b4e7d2Svarconst Block *LoopOp::getHeaderBlock() { 35381b4e7d2Svarconst assert(!getBody().empty() && "op region should not be empty!"); 35481b4e7d2Svarconst // The second block is the loop header block. 35581b4e7d2Svarconst return &*std::next(getBody().begin()); 35681b4e7d2Svarconst } 35781b4e7d2Svarconst 35881b4e7d2Svarconst Block *LoopOp::getContinueBlock() { 35981b4e7d2Svarconst assert(!getBody().empty() && "op region should not be empty!"); 36081b4e7d2Svarconst // The second to last block is the loop continue block. 36181b4e7d2Svarconst return &*std::prev(getBody().end(), 2); 36281b4e7d2Svarconst } 36381b4e7d2Svarconst 36481b4e7d2Svarconst Block *LoopOp::getMergeBlock() { 36581b4e7d2Svarconst assert(!getBody().empty() && "op region should not be empty!"); 36681b4e7d2Svarconst // The last block is the loop merge block. 36781b4e7d2Svarconst return &getBody().back(); 36881b4e7d2Svarconst } 36981b4e7d2Svarconst 37091d5653eSMatthias Springer void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { 37181b4e7d2Svarconst assert(getBody().empty() && "entry and merge block already exist"); 37291d5653eSMatthias Springer OpBuilder::InsertionGuard g(builder); 37391d5653eSMatthias Springer builder.createBlock(&getBody()); 37491d5653eSMatthias Springer builder.createBlock(&getBody()); 37581b4e7d2Svarconst 37681b4e7d2Svarconst // Add a spirv.mlir.merge op into the merge block. 37781b4e7d2Svarconst builder.create<spirv::MergeOp>(getLoc()); 37881b4e7d2Svarconst } 37981b4e7d2Svarconst 38081b4e7d2Svarconst //===----------------------------------------------------------------------===// 38181b4e7d2Svarconst // spirv.mlir.merge 38281b4e7d2Svarconst //===----------------------------------------------------------------------===// 38381b4e7d2Svarconst 38481b4e7d2Svarconst LogicalResult MergeOp::verify() { 38581b4e7d2Svarconst auto *parentOp = (*this)->getParentOp(); 38681b4e7d2Svarconst if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp)) 38781b4e7d2Svarconst return emitOpError( 38881b4e7d2Svarconst "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'"); 38981b4e7d2Svarconst 39081b4e7d2Svarconst // TODO: This check should be done in `verifyRegions` of parent op. 39181b4e7d2Svarconst Block &parentLastBlock = (*this)->getParentRegion()->back(); 39281b4e7d2Svarconst if (getOperation() != parentLastBlock.getTerminator()) 39381b4e7d2Svarconst return emitOpError("can only be used in the last block of " 39481b4e7d2Svarconst "'spirv.mlir.selection' or 'spirv.mlir.loop'"); 39581b4e7d2Svarconst return success(); 39681b4e7d2Svarconst } 39781b4e7d2Svarconst 39881b4e7d2Svarconst //===----------------------------------------------------------------------===// 39981b4e7d2Svarconst // spirv.Return 40081b4e7d2Svarconst //===----------------------------------------------------------------------===// 40181b4e7d2Svarconst 40281b4e7d2Svarconst LogicalResult ReturnOp::verify() { 40381b4e7d2Svarconst // Verification is performed in spirv.func op. 40481b4e7d2Svarconst return success(); 40581b4e7d2Svarconst } 40681b4e7d2Svarconst 40781b4e7d2Svarconst //===----------------------------------------------------------------------===// 40881b4e7d2Svarconst // spirv.ReturnValue 40981b4e7d2Svarconst //===----------------------------------------------------------------------===// 41081b4e7d2Svarconst 41181b4e7d2Svarconst LogicalResult ReturnValueOp::verify() { 41281b4e7d2Svarconst // Verification is performed in spirv.func op. 41381b4e7d2Svarconst return success(); 41481b4e7d2Svarconst } 41581b4e7d2Svarconst 41681b4e7d2Svarconst //===----------------------------------------------------------------------===// 41781b4e7d2Svarconst // spirv.Select 41881b4e7d2Svarconst //===----------------------------------------------------------------------===// 41981b4e7d2Svarconst 42081b4e7d2Svarconst LogicalResult SelectOp::verify() { 42181b4e7d2Svarconst if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) { 42281b4e7d2Svarconst auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType()); 42381b4e7d2Svarconst if (!resultVectorTy) { 42481b4e7d2Svarconst return emitOpError("result expected to be of vector type when " 42581b4e7d2Svarconst "condition is of vector type"); 42681b4e7d2Svarconst } 42781b4e7d2Svarconst if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { 42881b4e7d2Svarconst return emitOpError("result should have the same number of elements as " 42981b4e7d2Svarconst "the condition when condition is of vector type"); 43081b4e7d2Svarconst } 43181b4e7d2Svarconst } 43281b4e7d2Svarconst return success(); 43381b4e7d2Svarconst } 43481b4e7d2Svarconst 4356d9eb31cSLei Zhang // Custom availability implementation is needed for spirv.Select given the 4366d9eb31cSLei Zhang // syntax changes starting v1.4. 4376d9eb31cSLei Zhang SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() { 4386d9eb31cSLei Zhang return {}; 4396d9eb31cSLei Zhang } 4406d9eb31cSLei Zhang SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() { 4416d9eb31cSLei Zhang return {}; 4426d9eb31cSLei Zhang } 4436d9eb31cSLei Zhang std::optional<spirv::Version> SelectOp::getMinVersion() { 4446d9eb31cSLei Zhang // Per the spec, "Before version 1.4, results are only computed per 4456d9eb31cSLei Zhang // component." 4466d9eb31cSLei Zhang if (isa<spirv::ScalarType>(getCondition().getType()) && 4476d9eb31cSLei Zhang isa<spirv::CompositeType>(getType())) 4486d9eb31cSLei Zhang return Version::V_1_4; 4496d9eb31cSLei Zhang 4506d9eb31cSLei Zhang return Version::V_1_0; 4516d9eb31cSLei Zhang } 4526d9eb31cSLei Zhang std::optional<spirv::Version> SelectOp::getMaxVersion() { 4536d9eb31cSLei Zhang return Version::V_1_6; 4546d9eb31cSLei Zhang } 4556d9eb31cSLei Zhang 45681b4e7d2Svarconst //===----------------------------------------------------------------------===// 45781b4e7d2Svarconst // spirv.mlir.selection 45881b4e7d2Svarconst //===----------------------------------------------------------------------===// 45981b4e7d2Svarconst 46081b4e7d2Svarconst ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) { 46181b4e7d2Svarconst if (parseControlAttribute<spirv::SelectionControlAttr, 46281b4e7d2Svarconst spirv::SelectionControl>(parser, result)) 46381b4e7d2Svarconst return failure(); 46481b4e7d2Svarconst return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); 46581b4e7d2Svarconst } 46681b4e7d2Svarconst 46781b4e7d2Svarconst void SelectionOp::print(OpAsmPrinter &printer) { 46881b4e7d2Svarconst auto control = getSelectionControl(); 46981b4e7d2Svarconst if (control != spirv::SelectionControl::None) 47081b4e7d2Svarconst printer << " control(" << spirv::stringifySelectionControl(control) << ")"; 47181b4e7d2Svarconst printer << ' '; 47281b4e7d2Svarconst printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 47381b4e7d2Svarconst /*printBlockTerminators=*/true); 47481b4e7d2Svarconst } 47581b4e7d2Svarconst 47681b4e7d2Svarconst LogicalResult SelectionOp::verifyRegions() { 47781b4e7d2Svarconst auto *op = getOperation(); 47881b4e7d2Svarconst 47981b4e7d2Svarconst // We need to verify that the blocks follow the following layout: 48081b4e7d2Svarconst // 48181b4e7d2Svarconst // +--------------+ 48281b4e7d2Svarconst // | header block | 48381b4e7d2Svarconst // +--------------+ 48481b4e7d2Svarconst // / | \ 48581b4e7d2Svarconst // ... 48681b4e7d2Svarconst // 48781b4e7d2Svarconst // 48881b4e7d2Svarconst // +---------+ +---------+ +---------+ 48981b4e7d2Svarconst // | case #0 | | case #1 | | case #2 | ... 49081b4e7d2Svarconst // +---------+ +---------+ +---------+ 49181b4e7d2Svarconst // 49281b4e7d2Svarconst // 49381b4e7d2Svarconst // ... 49481b4e7d2Svarconst // \ | / 49581b4e7d2Svarconst // v 49681b4e7d2Svarconst // +-------------+ 49781b4e7d2Svarconst // | merge block | 49881b4e7d2Svarconst // +-------------+ 49981b4e7d2Svarconst 50081b4e7d2Svarconst auto ®ion = op->getRegion(0); 50181b4e7d2Svarconst // Allow empty region as a degenerated case, which can come from 50281b4e7d2Svarconst // optimizations. 50381b4e7d2Svarconst if (region.empty()) 50481b4e7d2Svarconst return success(); 50581b4e7d2Svarconst 50681b4e7d2Svarconst // The last block is the merge block. 50781b4e7d2Svarconst if (!isMergeBlock(region.back())) 50881b4e7d2Svarconst return emitOpError("last block must be the merge block with only one " 50981b4e7d2Svarconst "'spirv.mlir.merge' op"); 51081b4e7d2Svarconst 51181b4e7d2Svarconst if (std::next(region.begin()) == region.end()) 51281b4e7d2Svarconst return emitOpError("must have a selection header block"); 51381b4e7d2Svarconst 51481b4e7d2Svarconst return success(); 51581b4e7d2Svarconst } 51681b4e7d2Svarconst 51781b4e7d2Svarconst Block *SelectionOp::getHeaderBlock() { 51881b4e7d2Svarconst assert(!getBody().empty() && "op region should not be empty!"); 51981b4e7d2Svarconst // The first block is the loop header block. 52081b4e7d2Svarconst return &getBody().front(); 52181b4e7d2Svarconst } 52281b4e7d2Svarconst 52381b4e7d2Svarconst Block *SelectionOp::getMergeBlock() { 52481b4e7d2Svarconst assert(!getBody().empty() && "op region should not be empty!"); 52581b4e7d2Svarconst // The last block is the loop merge block. 52681b4e7d2Svarconst return &getBody().back(); 52781b4e7d2Svarconst } 52881b4e7d2Svarconst 52991d5653eSMatthias Springer void SelectionOp::addMergeBlock(OpBuilder &builder) { 53081b4e7d2Svarconst assert(getBody().empty() && "entry and merge block already exist"); 53191d5653eSMatthias Springer OpBuilder::InsertionGuard guard(builder); 53291d5653eSMatthias Springer builder.createBlock(&getBody()); 53381b4e7d2Svarconst 53481b4e7d2Svarconst // Add a spirv.mlir.merge op into the merge block. 53581b4e7d2Svarconst builder.create<spirv::MergeOp>(getLoc()); 53681b4e7d2Svarconst } 53781b4e7d2Svarconst 53881b4e7d2Svarconst SelectionOp 53981b4e7d2Svarconst SelectionOp::createIfThen(Location loc, Value condition, 54081b4e7d2Svarconst function_ref<void(OpBuilder &builder)> thenBody, 54181b4e7d2Svarconst OpBuilder &builder) { 54281b4e7d2Svarconst auto selectionOp = 54381b4e7d2Svarconst builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); 54481b4e7d2Svarconst 54591d5653eSMatthias Springer selectionOp.addMergeBlock(builder); 54681b4e7d2Svarconst Block *mergeBlock = selectionOp.getMergeBlock(); 54781b4e7d2Svarconst Block *thenBlock = nullptr; 54881b4e7d2Svarconst 54981b4e7d2Svarconst // Build the "then" block. 55081b4e7d2Svarconst { 55181b4e7d2Svarconst OpBuilder::InsertionGuard guard(builder); 55281b4e7d2Svarconst thenBlock = builder.createBlock(mergeBlock); 55381b4e7d2Svarconst thenBody(builder); 55481b4e7d2Svarconst builder.create<spirv::BranchOp>(loc, mergeBlock); 55581b4e7d2Svarconst } 55681b4e7d2Svarconst 55781b4e7d2Svarconst // Build the header block. 55881b4e7d2Svarconst { 55981b4e7d2Svarconst OpBuilder::InsertionGuard guard(builder); 56081b4e7d2Svarconst builder.createBlock(thenBlock); 56181b4e7d2Svarconst builder.create<spirv::BranchConditionalOp>( 56281b4e7d2Svarconst loc, condition, thenBlock, 56381b4e7d2Svarconst /*trueArguments=*/ArrayRef<Value>(), mergeBlock, 56481b4e7d2Svarconst /*falseArguments=*/ArrayRef<Value>()); 56581b4e7d2Svarconst } 56681b4e7d2Svarconst 56781b4e7d2Svarconst return selectionOp; 56881b4e7d2Svarconst } 56981b4e7d2Svarconst 57081b4e7d2Svarconst //===----------------------------------------------------------------------===// 57181b4e7d2Svarconst // spirv.Unreachable 57281b4e7d2Svarconst //===----------------------------------------------------------------------===// 57381b4e7d2Svarconst 57481b4e7d2Svarconst LogicalResult spirv::UnreachableOp::verify() { 57581b4e7d2Svarconst auto *block = (*this)->getBlock(); 57681b4e7d2Svarconst // Fast track: if this is in entry block, its invalid. Otherwise, if no 57781b4e7d2Svarconst // predecessors, it's valid. 57881b4e7d2Svarconst if (block->isEntryBlock()) 57981b4e7d2Svarconst return emitOpError("cannot be used in reachable block"); 58081b4e7d2Svarconst if (block->hasNoPredecessors()) 58181b4e7d2Svarconst return success(); 58281b4e7d2Svarconst 58381b4e7d2Svarconst // TODO: further verification needs to analyze reachability from 58481b4e7d2Svarconst // the entry block. 58581b4e7d2Svarconst 58681b4e7d2Svarconst return success(); 58781b4e7d2Svarconst } 58881b4e7d2Svarconst 58981b4e7d2Svarconst } // namespace mlir::spirv 590