xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
181b4e7d2Svarconst //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops  -----------------===//
281b4e7d2Svarconst //
381b4e7d2Svarconst // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481b4e7d2Svarconst // See https://llvm.org/LICENSE.txt for license information.
581b4e7d2Svarconst // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681b4e7d2Svarconst //
781b4e7d2Svarconst //===----------------------------------------------------------------------===//
881b4e7d2Svarconst //
981b4e7d2Svarconst // Defines the control flow operations in the SPIR-V dialect.
1081b4e7d2Svarconst //
1181b4e7d2Svarconst //===----------------------------------------------------------------------===//
1281b4e7d2Svarconst 
136d9eb31cSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1481b4e7d2Svarconst #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
156d9eb31cSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1681b4e7d2Svarconst #include "mlir/Interfaces/CallInterfaces.h"
1781b4e7d2Svarconst 
1881b4e7d2Svarconst #include "SPIRVOpUtils.h"
1981b4e7d2Svarconst #include "SPIRVParsingUtils.h"
2081b4e7d2Svarconst 
2181b4e7d2Svarconst using namespace mlir::spirv::AttrNames;
2281b4e7d2Svarconst 
2381b4e7d2Svarconst namespace mlir::spirv {
2481b4e7d2Svarconst 
2581b4e7d2Svarconst /// Parses Function, Selection and Loop control attributes. If no control is
2681b4e7d2Svarconst /// specified, "None" is used as a default.
2781b4e7d2Svarconst template <typename EnumAttrClass, typename EnumClass>
2881b4e7d2Svarconst static ParseResult
2981b4e7d2Svarconst parseControlAttribute(OpAsmParser &parser, OperationState &state,
3081b4e7d2Svarconst                       StringRef attrName = spirv::attributeName<EnumClass>()) {
3181b4e7d2Svarconst   if (succeeded(parser.parseOptionalKeyword(kControl))) {
3281b4e7d2Svarconst     EnumClass control;
3381b4e7d2Svarconst     if (parser.parseLParen() ||
3481b4e7d2Svarconst         spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
3581b4e7d2Svarconst         parser.parseRParen())
3681b4e7d2Svarconst       return failure();
3781b4e7d2Svarconst     return success();
3881b4e7d2Svarconst   }
3981b4e7d2Svarconst   // Set control to "None" otherwise.
4081b4e7d2Svarconst   Builder builder = parser.getBuilder();
4181b4e7d2Svarconst   state.addAttribute(attrName,
4281b4e7d2Svarconst                      builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
4381b4e7d2Svarconst   return success();
4481b4e7d2Svarconst }
4581b4e7d2Svarconst 
4681b4e7d2Svarconst //===----------------------------------------------------------------------===//
4781b4e7d2Svarconst // spirv.BranchOp
4881b4e7d2Svarconst //===----------------------------------------------------------------------===//
4981b4e7d2Svarconst 
5081b4e7d2Svarconst SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
5181b4e7d2Svarconst   assert(index == 0 && "invalid successor index");
5281b4e7d2Svarconst   return SuccessorOperands(0, getTargetOperandsMutable());
5381b4e7d2Svarconst }
5481b4e7d2Svarconst 
5581b4e7d2Svarconst //===----------------------------------------------------------------------===//
5681b4e7d2Svarconst // spirv.BranchConditionalOp
5781b4e7d2Svarconst //===----------------------------------------------------------------------===//
5881b4e7d2Svarconst 
5981b4e7d2Svarconst SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
6081b4e7d2Svarconst   assert(index < 2 && "invalid successor index");
6181b4e7d2Svarconst   return SuccessorOperands(index == kTrueIndex
6281b4e7d2Svarconst                                ? getTrueTargetOperandsMutable()
6381b4e7d2Svarconst                                : getFalseTargetOperandsMutable());
6481b4e7d2Svarconst }
6581b4e7d2Svarconst 
6681b4e7d2Svarconst ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
6781b4e7d2Svarconst                                        OperationState &result) {
6881b4e7d2Svarconst   auto &builder = parser.getBuilder();
6981b4e7d2Svarconst   OpAsmParser::UnresolvedOperand condInfo;
7081b4e7d2Svarconst   Block *dest;
7181b4e7d2Svarconst 
7281b4e7d2Svarconst   // Parse the condition.
7381b4e7d2Svarconst   Type boolTy = builder.getI1Type();
7481b4e7d2Svarconst   if (parser.parseOperand(condInfo) ||
7581b4e7d2Svarconst       parser.resolveOperand(condInfo, boolTy, result.operands))
7681b4e7d2Svarconst     return failure();
7781b4e7d2Svarconst 
7881b4e7d2Svarconst   // Parse the optional branch weights.
7981b4e7d2Svarconst   if (succeeded(parser.parseOptionalLSquare())) {
8081b4e7d2Svarconst     IntegerAttr trueWeight, falseWeight;
8181b4e7d2Svarconst     NamedAttrList weights;
8281b4e7d2Svarconst 
8381b4e7d2Svarconst     auto i32Type = builder.getIntegerType(32);
8481b4e7d2Svarconst     if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
8581b4e7d2Svarconst         parser.parseComma() ||
8681b4e7d2Svarconst         parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
8781b4e7d2Svarconst         parser.parseRSquare())
8881b4e7d2Svarconst       return failure();
8981b4e7d2Svarconst 
901d5e3b2dStw-ilson     StringAttr branchWeightsAttrName =
911d5e3b2dStw-ilson         BranchConditionalOp::getBranchWeightsAttrName(result.name);
921d5e3b2dStw-ilson     result.addAttribute(branchWeightsAttrName,
9381b4e7d2Svarconst                         builder.getArrayAttr({trueWeight, falseWeight}));
9481b4e7d2Svarconst   }
9581b4e7d2Svarconst 
9681b4e7d2Svarconst   // Parse the true branch.
9781b4e7d2Svarconst   SmallVector<Value, 4> trueOperands;
9881b4e7d2Svarconst   if (parser.parseComma() ||
9981b4e7d2Svarconst       parser.parseSuccessorAndUseList(dest, trueOperands))
10081b4e7d2Svarconst     return failure();
10181b4e7d2Svarconst   result.addSuccessors(dest);
10281b4e7d2Svarconst   result.addOperands(trueOperands);
10381b4e7d2Svarconst 
10481b4e7d2Svarconst   // Parse the false branch.
10581b4e7d2Svarconst   SmallVector<Value, 4> falseOperands;
10681b4e7d2Svarconst   if (parser.parseComma() ||
10781b4e7d2Svarconst       parser.parseSuccessorAndUseList(dest, falseOperands))
10881b4e7d2Svarconst     return failure();
10981b4e7d2Svarconst   result.addSuccessors(dest);
11081b4e7d2Svarconst   result.addOperands(falseOperands);
11181b4e7d2Svarconst   result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
11281b4e7d2Svarconst                       builder.getDenseI32ArrayAttr(
11381b4e7d2Svarconst                           {1, static_cast<int32_t>(trueOperands.size()),
11481b4e7d2Svarconst                            static_cast<int32_t>(falseOperands.size())}));
11581b4e7d2Svarconst 
11681b4e7d2Svarconst   return success();
11781b4e7d2Svarconst }
11881b4e7d2Svarconst 
11981b4e7d2Svarconst void BranchConditionalOp::print(OpAsmPrinter &printer) {
12081b4e7d2Svarconst   printer << ' ' << getCondition();
12181b4e7d2Svarconst 
12281b4e7d2Svarconst   if (auto weights = getBranchWeights()) {
12381b4e7d2Svarconst     printer << " [";
12481b4e7d2Svarconst     llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
12581b4e7d2Svarconst       printer << llvm::cast<IntegerAttr>(a).getInt();
12681b4e7d2Svarconst     });
12781b4e7d2Svarconst     printer << "]";
12881b4e7d2Svarconst   }
12981b4e7d2Svarconst 
13081b4e7d2Svarconst   printer << ", ";
13181b4e7d2Svarconst   printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
13281b4e7d2Svarconst   printer << ", ";
13381b4e7d2Svarconst   printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
13481b4e7d2Svarconst }
13581b4e7d2Svarconst 
13681b4e7d2Svarconst LogicalResult BranchConditionalOp::verify() {
13781b4e7d2Svarconst   if (auto weights = getBranchWeights()) {
13881b4e7d2Svarconst     if (weights->getValue().size() != 2) {
13981b4e7d2Svarconst       return emitOpError("must have exactly two branch weights");
14081b4e7d2Svarconst     }
14181b4e7d2Svarconst     if (llvm::all_of(*weights, [](Attribute attr) {
14281b4e7d2Svarconst           return llvm::cast<IntegerAttr>(attr).getValue().isZero();
14381b4e7d2Svarconst         }))
14481b4e7d2Svarconst       return emitOpError("branch weights cannot both be zero");
14581b4e7d2Svarconst   }
14681b4e7d2Svarconst 
14781b4e7d2Svarconst   return success();
14881b4e7d2Svarconst }
14981b4e7d2Svarconst 
15081b4e7d2Svarconst //===----------------------------------------------------------------------===//
15181b4e7d2Svarconst // spirv.FunctionCall
15281b4e7d2Svarconst //===----------------------------------------------------------------------===//
15381b4e7d2Svarconst 
15481b4e7d2Svarconst LogicalResult FunctionCallOp::verify() {
15581b4e7d2Svarconst   auto fnName = getCalleeAttr();
15681b4e7d2Svarconst 
15781b4e7d2Svarconst   auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
15881b4e7d2Svarconst       SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
15981b4e7d2Svarconst   if (!funcOp) {
16081b4e7d2Svarconst     return emitOpError("callee function '")
16181b4e7d2Svarconst            << fnName.getValue() << "' not found in nearest symbol table";
16281b4e7d2Svarconst   }
16381b4e7d2Svarconst 
16481b4e7d2Svarconst   auto functionType = funcOp.getFunctionType();
16581b4e7d2Svarconst 
16681b4e7d2Svarconst   if (getNumResults() > 1) {
16781b4e7d2Svarconst     return emitOpError(
16881b4e7d2Svarconst                "expected callee function to have 0 or 1 result, but provided ")
16981b4e7d2Svarconst            << getNumResults();
17081b4e7d2Svarconst   }
17181b4e7d2Svarconst 
17281b4e7d2Svarconst   if (functionType.getNumInputs() != getNumOperands()) {
17381b4e7d2Svarconst     return emitOpError("has incorrect number of operands for callee: expected ")
17481b4e7d2Svarconst            << functionType.getNumInputs() << ", but provided "
17581b4e7d2Svarconst            << getNumOperands();
17681b4e7d2Svarconst   }
17781b4e7d2Svarconst 
17881b4e7d2Svarconst   for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
17981b4e7d2Svarconst     if (getOperand(i).getType() != functionType.getInput(i)) {
18081b4e7d2Svarconst       return emitOpError("operand type mismatch: expected operand type ")
18181b4e7d2Svarconst              << functionType.getInput(i) << ", but provided "
18281b4e7d2Svarconst              << getOperand(i).getType() << " for operand number " << i;
18381b4e7d2Svarconst     }
18481b4e7d2Svarconst   }
18581b4e7d2Svarconst 
18681b4e7d2Svarconst   if (functionType.getNumResults() != getNumResults()) {
18781b4e7d2Svarconst     return emitOpError(
18881b4e7d2Svarconst                "has incorrect number of results has for callee: expected ")
18981b4e7d2Svarconst            << functionType.getNumResults() << ", but provided "
19081b4e7d2Svarconst            << getNumResults();
19181b4e7d2Svarconst   }
19281b4e7d2Svarconst 
19381b4e7d2Svarconst   if (getNumResults() &&
19481b4e7d2Svarconst       (getResult(0).getType() != functionType.getResult(0))) {
19581b4e7d2Svarconst     return emitOpError("result type mismatch: expected ")
19681b4e7d2Svarconst            << functionType.getResult(0) << ", but provided "
19781b4e7d2Svarconst            << getResult(0).getType();
19881b4e7d2Svarconst   }
19981b4e7d2Svarconst 
20081b4e7d2Svarconst   return success();
20181b4e7d2Svarconst }
20281b4e7d2Svarconst 
20381b4e7d2Svarconst CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
2041d5e3b2dStw-ilson   return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
20581b4e7d2Svarconst }
20681b4e7d2Svarconst 
20781b4e7d2Svarconst void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
208*129f1001SKazu Hirata   (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
20981b4e7d2Svarconst }
21081b4e7d2Svarconst 
21181b4e7d2Svarconst Operation::operand_range FunctionCallOp::getArgOperands() {
21281b4e7d2Svarconst   return getArguments();
21381b4e7d2Svarconst }
21481b4e7d2Svarconst 
215d790a217SMartin Erhart MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
216d790a217SMartin Erhart   return getArgumentsMutable();
217d790a217SMartin Erhart }
218d790a217SMartin Erhart 
21981b4e7d2Svarconst //===----------------------------------------------------------------------===//
22081b4e7d2Svarconst // spirv.mlir.loop
22181b4e7d2Svarconst //===----------------------------------------------------------------------===//
22281b4e7d2Svarconst 
22381b4e7d2Svarconst void LoopOp::build(OpBuilder &builder, OperationState &state) {
22481b4e7d2Svarconst   state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
22581b4e7d2Svarconst                                          spirv::LoopControl::None));
22681b4e7d2Svarconst   state.addRegion();
22781b4e7d2Svarconst }
22881b4e7d2Svarconst 
22981b4e7d2Svarconst ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
23081b4e7d2Svarconst   if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
23181b4e7d2Svarconst                                                                         result))
23281b4e7d2Svarconst     return failure();
23381b4e7d2Svarconst   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
23481b4e7d2Svarconst }
23581b4e7d2Svarconst 
23681b4e7d2Svarconst void LoopOp::print(OpAsmPrinter &printer) {
23781b4e7d2Svarconst   auto control = getLoopControl();
23881b4e7d2Svarconst   if (control != spirv::LoopControl::None)
23981b4e7d2Svarconst     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
24081b4e7d2Svarconst   printer << ' ';
24181b4e7d2Svarconst   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
24281b4e7d2Svarconst                       /*printBlockTerminators=*/true);
24381b4e7d2Svarconst }
24481b4e7d2Svarconst 
24581b4e7d2Svarconst /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
24681b4e7d2Svarconst /// given `dstBlock`.
24781b4e7d2Svarconst static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
24881b4e7d2Svarconst   // Check that there is only one op in the `srcBlock`.
24981b4e7d2Svarconst   if (!llvm::hasSingleElement(srcBlock))
25081b4e7d2Svarconst     return false;
25181b4e7d2Svarconst 
25281b4e7d2Svarconst   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
25381b4e7d2Svarconst   return branchOp && branchOp.getSuccessor() == &dstBlock;
25481b4e7d2Svarconst }
25581b4e7d2Svarconst 
25681b4e7d2Svarconst /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
25781b4e7d2Svarconst static bool isMergeBlock(Block &block) {
25881b4e7d2Svarconst   return !block.empty() && std::next(block.begin()) == block.end() &&
25981b4e7d2Svarconst          isa<spirv::MergeOp>(block.front());
26081b4e7d2Svarconst }
26181b4e7d2Svarconst 
26281b4e7d2Svarconst LogicalResult LoopOp::verifyRegions() {
26381b4e7d2Svarconst   auto *op = getOperation();
26481b4e7d2Svarconst 
26581b4e7d2Svarconst   // We need to verify that the blocks follow the following layout:
26681b4e7d2Svarconst   //
26781b4e7d2Svarconst   //                     +-------------+
26881b4e7d2Svarconst   //                     | entry block |
26981b4e7d2Svarconst   //                     +-------------+
27081b4e7d2Svarconst   //                            |
27181b4e7d2Svarconst   //                            v
27281b4e7d2Svarconst   //                     +-------------+
27381b4e7d2Svarconst   //                     | loop header | <-----+
27481b4e7d2Svarconst   //                     +-------------+       |
27581b4e7d2Svarconst   //                                           |
27681b4e7d2Svarconst   //                           ...             |
27781b4e7d2Svarconst   //                          \ | /            |
27881b4e7d2Svarconst   //                            v              |
27981b4e7d2Svarconst   //                    +---------------+      |
28081b4e7d2Svarconst   //                    | loop continue | -----+
28181b4e7d2Svarconst   //                    +---------------+
28281b4e7d2Svarconst   //
28381b4e7d2Svarconst   //                           ...
28481b4e7d2Svarconst   //                          \ | /
28581b4e7d2Svarconst   //                            v
28681b4e7d2Svarconst   //                     +-------------+
28781b4e7d2Svarconst   //                     | merge block |
28881b4e7d2Svarconst   //                     +-------------+
28981b4e7d2Svarconst 
29081b4e7d2Svarconst   auto &region = op->getRegion(0);
29181b4e7d2Svarconst   // Allow empty region as a degenerated case, which can come from
29281b4e7d2Svarconst   // optimizations.
29381b4e7d2Svarconst   if (region.empty())
29481b4e7d2Svarconst     return success();
29581b4e7d2Svarconst 
29681b4e7d2Svarconst   // The last block is the merge block.
29781b4e7d2Svarconst   Block &merge = region.back();
29881b4e7d2Svarconst   if (!isMergeBlock(merge))
29981b4e7d2Svarconst     return emitOpError("last block must be the merge block with only one "
30081b4e7d2Svarconst                        "'spirv.mlir.merge' op");
30181b4e7d2Svarconst 
30281b4e7d2Svarconst   if (std::next(region.begin()) == region.end())
30381b4e7d2Svarconst     return emitOpError(
30481b4e7d2Svarconst         "must have an entry block branching to the loop header block");
30581b4e7d2Svarconst   // The first block is the entry block.
30681b4e7d2Svarconst   Block &entry = region.front();
30781b4e7d2Svarconst 
30881b4e7d2Svarconst   if (std::next(region.begin(), 2) == region.end())
30981b4e7d2Svarconst     return emitOpError(
31081b4e7d2Svarconst         "must have a loop header block branched from the entry block");
31181b4e7d2Svarconst   // The second block is the loop header block.
31281b4e7d2Svarconst   Block &header = *std::next(region.begin(), 1);
31381b4e7d2Svarconst 
31481b4e7d2Svarconst   if (!hasOneBranchOpTo(entry, header))
31581b4e7d2Svarconst     return emitOpError(
31681b4e7d2Svarconst         "entry block must only have one 'spirv.Branch' op to the second block");
31781b4e7d2Svarconst 
31881b4e7d2Svarconst   if (std::next(region.begin(), 3) == region.end())
31981b4e7d2Svarconst     return emitOpError(
32081b4e7d2Svarconst         "requires a loop continue block branching to the loop header block");
32181b4e7d2Svarconst   // The second to last block is the loop continue block.
32281b4e7d2Svarconst   Block &cont = *std::prev(region.end(), 2);
32381b4e7d2Svarconst 
32481b4e7d2Svarconst   // Make sure that we have a branch from the loop continue block to the loop
32581b4e7d2Svarconst   // header block.
32681b4e7d2Svarconst   if (llvm::none_of(
32781b4e7d2Svarconst           llvm::seq<unsigned>(0, cont.getNumSuccessors()),
32881b4e7d2Svarconst           [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
32981b4e7d2Svarconst     return emitOpError("second to last block must be the loop continue "
33081b4e7d2Svarconst                        "block that branches to the loop header block");
33181b4e7d2Svarconst 
33281b4e7d2Svarconst   // Make sure that no other blocks (except the entry and loop continue block)
33381b4e7d2Svarconst   // branches to the loop header block.
33481b4e7d2Svarconst   for (auto &block : llvm::make_range(std::next(region.begin(), 2),
33581b4e7d2Svarconst                                       std::prev(region.end(), 2))) {
33681b4e7d2Svarconst     for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
33781b4e7d2Svarconst       if (block.getSuccessor(i) == &header) {
33881b4e7d2Svarconst         return emitOpError("can only have the entry and loop continue "
33981b4e7d2Svarconst                            "block branching to the loop header block");
34081b4e7d2Svarconst       }
34181b4e7d2Svarconst     }
34281b4e7d2Svarconst   }
34381b4e7d2Svarconst 
34481b4e7d2Svarconst   return success();
34581b4e7d2Svarconst }
34681b4e7d2Svarconst 
34781b4e7d2Svarconst Block *LoopOp::getEntryBlock() {
34881b4e7d2Svarconst   assert(!getBody().empty() && "op region should not be empty!");
34981b4e7d2Svarconst   return &getBody().front();
35081b4e7d2Svarconst }
35181b4e7d2Svarconst 
35281b4e7d2Svarconst Block *LoopOp::getHeaderBlock() {
35381b4e7d2Svarconst   assert(!getBody().empty() && "op region should not be empty!");
35481b4e7d2Svarconst   // The second block is the loop header block.
35581b4e7d2Svarconst   return &*std::next(getBody().begin());
35681b4e7d2Svarconst }
35781b4e7d2Svarconst 
35881b4e7d2Svarconst Block *LoopOp::getContinueBlock() {
35981b4e7d2Svarconst   assert(!getBody().empty() && "op region should not be empty!");
36081b4e7d2Svarconst   // The second to last block is the loop continue block.
36181b4e7d2Svarconst   return &*std::prev(getBody().end(), 2);
36281b4e7d2Svarconst }
36381b4e7d2Svarconst 
36481b4e7d2Svarconst Block *LoopOp::getMergeBlock() {
36581b4e7d2Svarconst   assert(!getBody().empty() && "op region should not be empty!");
36681b4e7d2Svarconst   // The last block is the loop merge block.
36781b4e7d2Svarconst   return &getBody().back();
36881b4e7d2Svarconst }
36981b4e7d2Svarconst 
37091d5653eSMatthias Springer void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
37181b4e7d2Svarconst   assert(getBody().empty() && "entry and merge block already exist");
37291d5653eSMatthias Springer   OpBuilder::InsertionGuard g(builder);
37391d5653eSMatthias Springer   builder.createBlock(&getBody());
37491d5653eSMatthias Springer   builder.createBlock(&getBody());
37581b4e7d2Svarconst 
37681b4e7d2Svarconst   // Add a spirv.mlir.merge op into the merge block.
37781b4e7d2Svarconst   builder.create<spirv::MergeOp>(getLoc());
37881b4e7d2Svarconst }
37981b4e7d2Svarconst 
38081b4e7d2Svarconst //===----------------------------------------------------------------------===//
38181b4e7d2Svarconst // spirv.mlir.merge
38281b4e7d2Svarconst //===----------------------------------------------------------------------===//
38381b4e7d2Svarconst 
38481b4e7d2Svarconst LogicalResult MergeOp::verify() {
38581b4e7d2Svarconst   auto *parentOp = (*this)->getParentOp();
38681b4e7d2Svarconst   if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
38781b4e7d2Svarconst     return emitOpError(
38881b4e7d2Svarconst         "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
38981b4e7d2Svarconst 
39081b4e7d2Svarconst   // TODO: This check should be done in `verifyRegions` of parent op.
39181b4e7d2Svarconst   Block &parentLastBlock = (*this)->getParentRegion()->back();
39281b4e7d2Svarconst   if (getOperation() != parentLastBlock.getTerminator())
39381b4e7d2Svarconst     return emitOpError("can only be used in the last block of "
39481b4e7d2Svarconst                        "'spirv.mlir.selection' or 'spirv.mlir.loop'");
39581b4e7d2Svarconst   return success();
39681b4e7d2Svarconst }
39781b4e7d2Svarconst 
39881b4e7d2Svarconst //===----------------------------------------------------------------------===//
39981b4e7d2Svarconst // spirv.Return
40081b4e7d2Svarconst //===----------------------------------------------------------------------===//
40181b4e7d2Svarconst 
40281b4e7d2Svarconst LogicalResult ReturnOp::verify() {
40381b4e7d2Svarconst   // Verification is performed in spirv.func op.
40481b4e7d2Svarconst   return success();
40581b4e7d2Svarconst }
40681b4e7d2Svarconst 
40781b4e7d2Svarconst //===----------------------------------------------------------------------===//
40881b4e7d2Svarconst // spirv.ReturnValue
40981b4e7d2Svarconst //===----------------------------------------------------------------------===//
41081b4e7d2Svarconst 
41181b4e7d2Svarconst LogicalResult ReturnValueOp::verify() {
41281b4e7d2Svarconst   // Verification is performed in spirv.func op.
41381b4e7d2Svarconst   return success();
41481b4e7d2Svarconst }
41581b4e7d2Svarconst 
41681b4e7d2Svarconst //===----------------------------------------------------------------------===//
41781b4e7d2Svarconst // spirv.Select
41881b4e7d2Svarconst //===----------------------------------------------------------------------===//
41981b4e7d2Svarconst 
42081b4e7d2Svarconst LogicalResult SelectOp::verify() {
42181b4e7d2Svarconst   if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
42281b4e7d2Svarconst     auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
42381b4e7d2Svarconst     if (!resultVectorTy) {
42481b4e7d2Svarconst       return emitOpError("result expected to be of vector type when "
42581b4e7d2Svarconst                          "condition is of vector type");
42681b4e7d2Svarconst     }
42781b4e7d2Svarconst     if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
42881b4e7d2Svarconst       return emitOpError("result should have the same number of elements as "
42981b4e7d2Svarconst                          "the condition when condition is of vector type");
43081b4e7d2Svarconst     }
43181b4e7d2Svarconst   }
43281b4e7d2Svarconst   return success();
43381b4e7d2Svarconst }
43481b4e7d2Svarconst 
4356d9eb31cSLei Zhang // Custom availability implementation is needed for spirv.Select given the
4366d9eb31cSLei Zhang // syntax changes starting v1.4.
4376d9eb31cSLei Zhang SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
4386d9eb31cSLei Zhang   return {};
4396d9eb31cSLei Zhang }
4406d9eb31cSLei Zhang SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
4416d9eb31cSLei Zhang   return {};
4426d9eb31cSLei Zhang }
4436d9eb31cSLei Zhang std::optional<spirv::Version> SelectOp::getMinVersion() {
4446d9eb31cSLei Zhang   // Per the spec, "Before version 1.4, results are only computed per
4456d9eb31cSLei Zhang   // component."
4466d9eb31cSLei Zhang   if (isa<spirv::ScalarType>(getCondition().getType()) &&
4476d9eb31cSLei Zhang       isa<spirv::CompositeType>(getType()))
4486d9eb31cSLei Zhang     return Version::V_1_4;
4496d9eb31cSLei Zhang 
4506d9eb31cSLei Zhang   return Version::V_1_0;
4516d9eb31cSLei Zhang }
4526d9eb31cSLei Zhang std::optional<spirv::Version> SelectOp::getMaxVersion() {
4536d9eb31cSLei Zhang   return Version::V_1_6;
4546d9eb31cSLei Zhang }
4556d9eb31cSLei Zhang 
45681b4e7d2Svarconst //===----------------------------------------------------------------------===//
45781b4e7d2Svarconst // spirv.mlir.selection
45881b4e7d2Svarconst //===----------------------------------------------------------------------===//
45981b4e7d2Svarconst 
46081b4e7d2Svarconst ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
46181b4e7d2Svarconst   if (parseControlAttribute<spirv::SelectionControlAttr,
46281b4e7d2Svarconst                             spirv::SelectionControl>(parser, result))
46381b4e7d2Svarconst     return failure();
46481b4e7d2Svarconst   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
46581b4e7d2Svarconst }
46681b4e7d2Svarconst 
46781b4e7d2Svarconst void SelectionOp::print(OpAsmPrinter &printer) {
46881b4e7d2Svarconst   auto control = getSelectionControl();
46981b4e7d2Svarconst   if (control != spirv::SelectionControl::None)
47081b4e7d2Svarconst     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
47181b4e7d2Svarconst   printer << ' ';
47281b4e7d2Svarconst   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
47381b4e7d2Svarconst                       /*printBlockTerminators=*/true);
47481b4e7d2Svarconst }
47581b4e7d2Svarconst 
47681b4e7d2Svarconst LogicalResult SelectionOp::verifyRegions() {
47781b4e7d2Svarconst   auto *op = getOperation();
47881b4e7d2Svarconst 
47981b4e7d2Svarconst   // We need to verify that the blocks follow the following layout:
48081b4e7d2Svarconst   //
48181b4e7d2Svarconst   //                     +--------------+
48281b4e7d2Svarconst   //                     | header block |
48381b4e7d2Svarconst   //                     +--------------+
48481b4e7d2Svarconst   //                          / | \
48581b4e7d2Svarconst   //                           ...
48681b4e7d2Svarconst   //
48781b4e7d2Svarconst   //
48881b4e7d2Svarconst   //         +---------+   +---------+   +---------+
48981b4e7d2Svarconst   //         | case #0 |   | case #1 |   | case #2 |  ...
49081b4e7d2Svarconst   //         +---------+   +---------+   +---------+
49181b4e7d2Svarconst   //
49281b4e7d2Svarconst   //
49381b4e7d2Svarconst   //                           ...
49481b4e7d2Svarconst   //                          \ | /
49581b4e7d2Svarconst   //                            v
49681b4e7d2Svarconst   //                     +-------------+
49781b4e7d2Svarconst   //                     | merge block |
49881b4e7d2Svarconst   //                     +-------------+
49981b4e7d2Svarconst 
50081b4e7d2Svarconst   auto &region = op->getRegion(0);
50181b4e7d2Svarconst   // Allow empty region as a degenerated case, which can come from
50281b4e7d2Svarconst   // optimizations.
50381b4e7d2Svarconst   if (region.empty())
50481b4e7d2Svarconst     return success();
50581b4e7d2Svarconst 
50681b4e7d2Svarconst   // The last block is the merge block.
50781b4e7d2Svarconst   if (!isMergeBlock(region.back()))
50881b4e7d2Svarconst     return emitOpError("last block must be the merge block with only one "
50981b4e7d2Svarconst                        "'spirv.mlir.merge' op");
51081b4e7d2Svarconst 
51181b4e7d2Svarconst   if (std::next(region.begin()) == region.end())
51281b4e7d2Svarconst     return emitOpError("must have a selection header block");
51381b4e7d2Svarconst 
51481b4e7d2Svarconst   return success();
51581b4e7d2Svarconst }
51681b4e7d2Svarconst 
51781b4e7d2Svarconst Block *SelectionOp::getHeaderBlock() {
51881b4e7d2Svarconst   assert(!getBody().empty() && "op region should not be empty!");
51981b4e7d2Svarconst   // The first block is the loop header block.
52081b4e7d2Svarconst   return &getBody().front();
52181b4e7d2Svarconst }
52281b4e7d2Svarconst 
52381b4e7d2Svarconst Block *SelectionOp::getMergeBlock() {
52481b4e7d2Svarconst   assert(!getBody().empty() && "op region should not be empty!");
52581b4e7d2Svarconst   // The last block is the loop merge block.
52681b4e7d2Svarconst   return &getBody().back();
52781b4e7d2Svarconst }
52881b4e7d2Svarconst 
52991d5653eSMatthias Springer void SelectionOp::addMergeBlock(OpBuilder &builder) {
53081b4e7d2Svarconst   assert(getBody().empty() && "entry and merge block already exist");
53191d5653eSMatthias Springer   OpBuilder::InsertionGuard guard(builder);
53291d5653eSMatthias Springer   builder.createBlock(&getBody());
53381b4e7d2Svarconst 
53481b4e7d2Svarconst   // Add a spirv.mlir.merge op into the merge block.
53581b4e7d2Svarconst   builder.create<spirv::MergeOp>(getLoc());
53681b4e7d2Svarconst }
53781b4e7d2Svarconst 
53881b4e7d2Svarconst SelectionOp
53981b4e7d2Svarconst SelectionOp::createIfThen(Location loc, Value condition,
54081b4e7d2Svarconst                           function_ref<void(OpBuilder &builder)> thenBody,
54181b4e7d2Svarconst                           OpBuilder &builder) {
54281b4e7d2Svarconst   auto selectionOp =
54381b4e7d2Svarconst       builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
54481b4e7d2Svarconst 
54591d5653eSMatthias Springer   selectionOp.addMergeBlock(builder);
54681b4e7d2Svarconst   Block *mergeBlock = selectionOp.getMergeBlock();
54781b4e7d2Svarconst   Block *thenBlock = nullptr;
54881b4e7d2Svarconst 
54981b4e7d2Svarconst   // Build the "then" block.
55081b4e7d2Svarconst   {
55181b4e7d2Svarconst     OpBuilder::InsertionGuard guard(builder);
55281b4e7d2Svarconst     thenBlock = builder.createBlock(mergeBlock);
55381b4e7d2Svarconst     thenBody(builder);
55481b4e7d2Svarconst     builder.create<spirv::BranchOp>(loc, mergeBlock);
55581b4e7d2Svarconst   }
55681b4e7d2Svarconst 
55781b4e7d2Svarconst   // Build the header block.
55881b4e7d2Svarconst   {
55981b4e7d2Svarconst     OpBuilder::InsertionGuard guard(builder);
56081b4e7d2Svarconst     builder.createBlock(thenBlock);
56181b4e7d2Svarconst     builder.create<spirv::BranchConditionalOp>(
56281b4e7d2Svarconst         loc, condition, thenBlock,
56381b4e7d2Svarconst         /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
56481b4e7d2Svarconst         /*falseArguments=*/ArrayRef<Value>());
56581b4e7d2Svarconst   }
56681b4e7d2Svarconst 
56781b4e7d2Svarconst   return selectionOp;
56881b4e7d2Svarconst }
56981b4e7d2Svarconst 
57081b4e7d2Svarconst //===----------------------------------------------------------------------===//
57181b4e7d2Svarconst // spirv.Unreachable
57281b4e7d2Svarconst //===----------------------------------------------------------------------===//
57381b4e7d2Svarconst 
57481b4e7d2Svarconst LogicalResult spirv::UnreachableOp::verify() {
57581b4e7d2Svarconst   auto *block = (*this)->getBlock();
57681b4e7d2Svarconst   // Fast track: if this is in entry block, its invalid. Otherwise, if no
57781b4e7d2Svarconst   // predecessors, it's valid.
57881b4e7d2Svarconst   if (block->isEntryBlock())
57981b4e7d2Svarconst     return emitOpError("cannot be used in reachable block");
58081b4e7d2Svarconst   if (block->hasNoPredecessors())
58181b4e7d2Svarconst     return success();
58281b4e7d2Svarconst 
58381b4e7d2Svarconst   // TODO: further verification needs to analyze reachability from
58481b4e7d2Svarconst   // the entry block.
58581b4e7d2Svarconst 
58681b4e7d2Svarconst   return success();
58781b4e7d2Svarconst }
58881b4e7d2Svarconst 
58981b4e7d2Svarconst } // namespace mlir::spirv
590