xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops  -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the control flow operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
16 #include "mlir/Interfaces/CallInterfaces.h"
17 
18 #include "SPIRVOpUtils.h"
19 #include "SPIRVParsingUtils.h"
20 
21 using namespace mlir::spirv::AttrNames;
22 
23 namespace mlir::spirv {
24 
25 /// Parses Function, Selection and Loop control attributes. If no control is
26 /// specified, "None" is used as a default.
27 template <typename EnumAttrClass, typename EnumClass>
28 static ParseResult
29 parseControlAttribute(OpAsmParser &parser, OperationState &state,
30                       StringRef attrName = spirv::attributeName<EnumClass>()) {
31   if (succeeded(parser.parseOptionalKeyword(kControl))) {
32     EnumClass control;
33     if (parser.parseLParen() ||
34         spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
35         parser.parseRParen())
36       return failure();
37     return success();
38   }
39   // Set control to "None" otherwise.
40   Builder builder = parser.getBuilder();
41   state.addAttribute(attrName,
42                      builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
43   return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // spirv.BranchOp
48 //===----------------------------------------------------------------------===//
49 
50 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
51   assert(index == 0 && "invalid successor index");
52   return SuccessorOperands(0, getTargetOperandsMutable());
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // spirv.BranchConditionalOp
57 //===----------------------------------------------------------------------===//
58 
59 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
60   assert(index < 2 && "invalid successor index");
61   return SuccessorOperands(index == kTrueIndex
62                                ? getTrueTargetOperandsMutable()
63                                : getFalseTargetOperandsMutable());
64 }
65 
66 ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
67                                        OperationState &result) {
68   auto &builder = parser.getBuilder();
69   OpAsmParser::UnresolvedOperand condInfo;
70   Block *dest;
71 
72   // Parse the condition.
73   Type boolTy = builder.getI1Type();
74   if (parser.parseOperand(condInfo) ||
75       parser.resolveOperand(condInfo, boolTy, result.operands))
76     return failure();
77 
78   // Parse the optional branch weights.
79   if (succeeded(parser.parseOptionalLSquare())) {
80     IntegerAttr trueWeight, falseWeight;
81     NamedAttrList weights;
82 
83     auto i32Type = builder.getIntegerType(32);
84     if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
85         parser.parseComma() ||
86         parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
87         parser.parseRSquare())
88       return failure();
89 
90     StringAttr branchWeightsAttrName =
91         BranchConditionalOp::getBranchWeightsAttrName(result.name);
92     result.addAttribute(branchWeightsAttrName,
93                         builder.getArrayAttr({trueWeight, falseWeight}));
94   }
95 
96   // Parse the true branch.
97   SmallVector<Value, 4> trueOperands;
98   if (parser.parseComma() ||
99       parser.parseSuccessorAndUseList(dest, trueOperands))
100     return failure();
101   result.addSuccessors(dest);
102   result.addOperands(trueOperands);
103 
104   // Parse the false branch.
105   SmallVector<Value, 4> falseOperands;
106   if (parser.parseComma() ||
107       parser.parseSuccessorAndUseList(dest, falseOperands))
108     return failure();
109   result.addSuccessors(dest);
110   result.addOperands(falseOperands);
111   result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
112                       builder.getDenseI32ArrayAttr(
113                           {1, static_cast<int32_t>(trueOperands.size()),
114                            static_cast<int32_t>(falseOperands.size())}));
115 
116   return success();
117 }
118 
119 void BranchConditionalOp::print(OpAsmPrinter &printer) {
120   printer << ' ' << getCondition();
121 
122   if (auto weights = getBranchWeights()) {
123     printer << " [";
124     llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
125       printer << llvm::cast<IntegerAttr>(a).getInt();
126     });
127     printer << "]";
128   }
129 
130   printer << ", ";
131   printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
132   printer << ", ";
133   printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
134 }
135 
136 LogicalResult BranchConditionalOp::verify() {
137   if (auto weights = getBranchWeights()) {
138     if (weights->getValue().size() != 2) {
139       return emitOpError("must have exactly two branch weights");
140     }
141     if (llvm::all_of(*weights, [](Attribute attr) {
142           return llvm::cast<IntegerAttr>(attr).getValue().isZero();
143         }))
144       return emitOpError("branch weights cannot both be zero");
145   }
146 
147   return success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // spirv.FunctionCall
152 //===----------------------------------------------------------------------===//
153 
154 LogicalResult FunctionCallOp::verify() {
155   auto fnName = getCalleeAttr();
156 
157   auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
158       SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
159   if (!funcOp) {
160     return emitOpError("callee function '")
161            << fnName.getValue() << "' not found in nearest symbol table";
162   }
163 
164   auto functionType = funcOp.getFunctionType();
165 
166   if (getNumResults() > 1) {
167     return emitOpError(
168                "expected callee function to have 0 or 1 result, but provided ")
169            << getNumResults();
170   }
171 
172   if (functionType.getNumInputs() != getNumOperands()) {
173     return emitOpError("has incorrect number of operands for callee: expected ")
174            << functionType.getNumInputs() << ", but provided "
175            << getNumOperands();
176   }
177 
178   for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
179     if (getOperand(i).getType() != functionType.getInput(i)) {
180       return emitOpError("operand type mismatch: expected operand type ")
181              << functionType.getInput(i) << ", but provided "
182              << getOperand(i).getType() << " for operand number " << i;
183     }
184   }
185 
186   if (functionType.getNumResults() != getNumResults()) {
187     return emitOpError(
188                "has incorrect number of results has for callee: expected ")
189            << functionType.getNumResults() << ", but provided "
190            << getNumResults();
191   }
192 
193   if (getNumResults() &&
194       (getResult(0).getType() != functionType.getResult(0))) {
195     return emitOpError("result type mismatch: expected ")
196            << functionType.getResult(0) << ", but provided "
197            << getResult(0).getType();
198   }
199 
200   return success();
201 }
202 
203 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
204   return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
205 }
206 
207 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
208   (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
209 }
210 
211 Operation::operand_range FunctionCallOp::getArgOperands() {
212   return getArguments();
213 }
214 
215 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
216   return getArgumentsMutable();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // spirv.mlir.loop
221 //===----------------------------------------------------------------------===//
222 
223 void LoopOp::build(OpBuilder &builder, OperationState &state) {
224   state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
225                                          spirv::LoopControl::None));
226   state.addRegion();
227 }
228 
229 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
230   if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
231                                                                         result))
232     return failure();
233   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
234 }
235 
236 void LoopOp::print(OpAsmPrinter &printer) {
237   auto control = getLoopControl();
238   if (control != spirv::LoopControl::None)
239     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
240   printer << ' ';
241   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
242                       /*printBlockTerminators=*/true);
243 }
244 
245 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
246 /// given `dstBlock`.
247 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
248   // Check that there is only one op in the `srcBlock`.
249   if (!llvm::hasSingleElement(srcBlock))
250     return false;
251 
252   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
253   return branchOp && branchOp.getSuccessor() == &dstBlock;
254 }
255 
256 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
257 static bool isMergeBlock(Block &block) {
258   return !block.empty() && std::next(block.begin()) == block.end() &&
259          isa<spirv::MergeOp>(block.front());
260 }
261 
262 LogicalResult LoopOp::verifyRegions() {
263   auto *op = getOperation();
264 
265   // We need to verify that the blocks follow the following layout:
266   //
267   //                     +-------------+
268   //                     | entry block |
269   //                     +-------------+
270   //                            |
271   //                            v
272   //                     +-------------+
273   //                     | loop header | <-----+
274   //                     +-------------+       |
275   //                                           |
276   //                           ...             |
277   //                          \ | /            |
278   //                            v              |
279   //                    +---------------+      |
280   //                    | loop continue | -----+
281   //                    +---------------+
282   //
283   //                           ...
284   //                          \ | /
285   //                            v
286   //                     +-------------+
287   //                     | merge block |
288   //                     +-------------+
289 
290   auto &region = op->getRegion(0);
291   // Allow empty region as a degenerated case, which can come from
292   // optimizations.
293   if (region.empty())
294     return success();
295 
296   // The last block is the merge block.
297   Block &merge = region.back();
298   if (!isMergeBlock(merge))
299     return emitOpError("last block must be the merge block with only one "
300                        "'spirv.mlir.merge' op");
301 
302   if (std::next(region.begin()) == region.end())
303     return emitOpError(
304         "must have an entry block branching to the loop header block");
305   // The first block is the entry block.
306   Block &entry = region.front();
307 
308   if (std::next(region.begin(), 2) == region.end())
309     return emitOpError(
310         "must have a loop header block branched from the entry block");
311   // The second block is the loop header block.
312   Block &header = *std::next(region.begin(), 1);
313 
314   if (!hasOneBranchOpTo(entry, header))
315     return emitOpError(
316         "entry block must only have one 'spirv.Branch' op to the second block");
317 
318   if (std::next(region.begin(), 3) == region.end())
319     return emitOpError(
320         "requires a loop continue block branching to the loop header block");
321   // The second to last block is the loop continue block.
322   Block &cont = *std::prev(region.end(), 2);
323 
324   // Make sure that we have a branch from the loop continue block to the loop
325   // header block.
326   if (llvm::none_of(
327           llvm::seq<unsigned>(0, cont.getNumSuccessors()),
328           [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
329     return emitOpError("second to last block must be the loop continue "
330                        "block that branches to the loop header block");
331 
332   // Make sure that no other blocks (except the entry and loop continue block)
333   // branches to the loop header block.
334   for (auto &block : llvm::make_range(std::next(region.begin(), 2),
335                                       std::prev(region.end(), 2))) {
336     for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
337       if (block.getSuccessor(i) == &header) {
338         return emitOpError("can only have the entry and loop continue "
339                            "block branching to the loop header block");
340       }
341     }
342   }
343 
344   return success();
345 }
346 
347 Block *LoopOp::getEntryBlock() {
348   assert(!getBody().empty() && "op region should not be empty!");
349   return &getBody().front();
350 }
351 
352 Block *LoopOp::getHeaderBlock() {
353   assert(!getBody().empty() && "op region should not be empty!");
354   // The second block is the loop header block.
355   return &*std::next(getBody().begin());
356 }
357 
358 Block *LoopOp::getContinueBlock() {
359   assert(!getBody().empty() && "op region should not be empty!");
360   // The second to last block is the loop continue block.
361   return &*std::prev(getBody().end(), 2);
362 }
363 
364 Block *LoopOp::getMergeBlock() {
365   assert(!getBody().empty() && "op region should not be empty!");
366   // The last block is the loop merge block.
367   return &getBody().back();
368 }
369 
370 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
371   assert(getBody().empty() && "entry and merge block already exist");
372   OpBuilder::InsertionGuard g(builder);
373   builder.createBlock(&getBody());
374   builder.createBlock(&getBody());
375 
376   // Add a spirv.mlir.merge op into the merge block.
377   builder.create<spirv::MergeOp>(getLoc());
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // spirv.mlir.merge
382 //===----------------------------------------------------------------------===//
383 
384 LogicalResult MergeOp::verify() {
385   auto *parentOp = (*this)->getParentOp();
386   if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
387     return emitOpError(
388         "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
389 
390   // TODO: This check should be done in `verifyRegions` of parent op.
391   Block &parentLastBlock = (*this)->getParentRegion()->back();
392   if (getOperation() != parentLastBlock.getTerminator())
393     return emitOpError("can only be used in the last block of "
394                        "'spirv.mlir.selection' or 'spirv.mlir.loop'");
395   return success();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // spirv.Return
400 //===----------------------------------------------------------------------===//
401 
402 LogicalResult ReturnOp::verify() {
403   // Verification is performed in spirv.func op.
404   return success();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // spirv.ReturnValue
409 //===----------------------------------------------------------------------===//
410 
411 LogicalResult ReturnValueOp::verify() {
412   // Verification is performed in spirv.func op.
413   return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // spirv.Select
418 //===----------------------------------------------------------------------===//
419 
420 LogicalResult SelectOp::verify() {
421   if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
422     auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
423     if (!resultVectorTy) {
424       return emitOpError("result expected to be of vector type when "
425                          "condition is of vector type");
426     }
427     if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
428       return emitOpError("result should have the same number of elements as "
429                          "the condition when condition is of vector type");
430     }
431   }
432   return success();
433 }
434 
435 // Custom availability implementation is needed for spirv.Select given the
436 // syntax changes starting v1.4.
437 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
438   return {};
439 }
440 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
441   return {};
442 }
443 std::optional<spirv::Version> SelectOp::getMinVersion() {
444   // Per the spec, "Before version 1.4, results are only computed per
445   // component."
446   if (isa<spirv::ScalarType>(getCondition().getType()) &&
447       isa<spirv::CompositeType>(getType()))
448     return Version::V_1_4;
449 
450   return Version::V_1_0;
451 }
452 std::optional<spirv::Version> SelectOp::getMaxVersion() {
453   return Version::V_1_6;
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // spirv.mlir.selection
458 //===----------------------------------------------------------------------===//
459 
460 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
461   if (parseControlAttribute<spirv::SelectionControlAttr,
462                             spirv::SelectionControl>(parser, result))
463     return failure();
464   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
465 }
466 
467 void SelectionOp::print(OpAsmPrinter &printer) {
468   auto control = getSelectionControl();
469   if (control != spirv::SelectionControl::None)
470     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
471   printer << ' ';
472   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
473                       /*printBlockTerminators=*/true);
474 }
475 
476 LogicalResult SelectionOp::verifyRegions() {
477   auto *op = getOperation();
478 
479   // We need to verify that the blocks follow the following layout:
480   //
481   //                     +--------------+
482   //                     | header block |
483   //                     +--------------+
484   //                          / | \
485   //                           ...
486   //
487   //
488   //         +---------+   +---------+   +---------+
489   //         | case #0 |   | case #1 |   | case #2 |  ...
490   //         +---------+   +---------+   +---------+
491   //
492   //
493   //                           ...
494   //                          \ | /
495   //                            v
496   //                     +-------------+
497   //                     | merge block |
498   //                     +-------------+
499 
500   auto &region = op->getRegion(0);
501   // Allow empty region as a degenerated case, which can come from
502   // optimizations.
503   if (region.empty())
504     return success();
505 
506   // The last block is the merge block.
507   if (!isMergeBlock(region.back()))
508     return emitOpError("last block must be the merge block with only one "
509                        "'spirv.mlir.merge' op");
510 
511   if (std::next(region.begin()) == region.end())
512     return emitOpError("must have a selection header block");
513 
514   return success();
515 }
516 
517 Block *SelectionOp::getHeaderBlock() {
518   assert(!getBody().empty() && "op region should not be empty!");
519   // The first block is the loop header block.
520   return &getBody().front();
521 }
522 
523 Block *SelectionOp::getMergeBlock() {
524   assert(!getBody().empty() && "op region should not be empty!");
525   // The last block is the loop merge block.
526   return &getBody().back();
527 }
528 
529 void SelectionOp::addMergeBlock(OpBuilder &builder) {
530   assert(getBody().empty() && "entry and merge block already exist");
531   OpBuilder::InsertionGuard guard(builder);
532   builder.createBlock(&getBody());
533 
534   // Add a spirv.mlir.merge op into the merge block.
535   builder.create<spirv::MergeOp>(getLoc());
536 }
537 
538 SelectionOp
539 SelectionOp::createIfThen(Location loc, Value condition,
540                           function_ref<void(OpBuilder &builder)> thenBody,
541                           OpBuilder &builder) {
542   auto selectionOp =
543       builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
544 
545   selectionOp.addMergeBlock(builder);
546   Block *mergeBlock = selectionOp.getMergeBlock();
547   Block *thenBlock = nullptr;
548 
549   // Build the "then" block.
550   {
551     OpBuilder::InsertionGuard guard(builder);
552     thenBlock = builder.createBlock(mergeBlock);
553     thenBody(builder);
554     builder.create<spirv::BranchOp>(loc, mergeBlock);
555   }
556 
557   // Build the header block.
558   {
559     OpBuilder::InsertionGuard guard(builder);
560     builder.createBlock(thenBlock);
561     builder.create<spirv::BranchConditionalOp>(
562         loc, condition, thenBlock,
563         /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
564         /*falseArguments=*/ArrayRef<Value>());
565   }
566 
567   return selectionOp;
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // spirv.Unreachable
572 //===----------------------------------------------------------------------===//
573 
574 LogicalResult spirv::UnreachableOp::verify() {
575   auto *block = (*this)->getBlock();
576   // Fast track: if this is in entry block, its invalid. Otherwise, if no
577   // predecessors, it's valid.
578   if (block->isEntryBlock())
579     return emitOpError("cannot be used in reachable block");
580   if (block->hasNoPredecessors())
581     return success();
582 
583   // TODO: further verification needs to analyze reachability from
584   // the entry block.
585 
586   return success();
587 }
588 
589 } // namespace mlir::spirv
590