1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
10
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
16 #include "mlir/Dialect/Index/IR/IndexDialect.h"
17 #include "mlir/Dialect/Index/IR/IndexOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
22
23 namespace mlir {
24 namespace linalg {
25 namespace {
26 /// Verify that the runtime sizes of the operands to linalg structured ops are
27 /// compatible with the runtime sizes inferred by composing the loop ranges with
28 /// the linalg op's indexing maps. This is similar to the verifier except that
29 /// here we insert IR to perform the verification at runtime.
30 template <typename T>
31 struct StructuredOpInterface
32 : public RuntimeVerifiableOpInterface::ExternalModel<
33 StructuredOpInterface<T>, T> {
generateRuntimeVerificationmlir::linalg::__anon31d6dceb0111::StructuredOpInterface34 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
35 Location loc) const {
36 auto linalgOp = llvm::cast<LinalgOp>(op);
37
38 SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
39 auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
40
41 auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
42 auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
43
44 // Subtract one from the loop ends before composing with the indexing map
45 transform(ends, ends.begin(), [&](OpFoldResult end) {
46 auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
47 return builder.createOrFold<index::SubOp>(loc, endValue, one);
48 });
49
50 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
51 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
52 auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
53 builder, loc, indexingMap, starts);
54 auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
55 builder, loc, indexingMap, ends);
56
57 for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
58 auto startIndex =
59 getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
60 auto endIndex =
61 getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
62
63 // Generate:
64 // minIndex = min(startIndex, endIndex)
65 // assert(minIndex >= 0)
66 // To ensure we do not generate a negative index. We take the minimum of
67 // the start and end indices in order to handle reverse loops such as
68 // `affine_map<(i) -> (3 - i)>`
69 auto min =
70 builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
71 auto cmpOp = builder.createOrFold<index::CmpOp>(
72 loc, index::IndexCmpPredicate::SGE, min, zero);
73 auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
74 linalgOp, "unexpected negative result on dimension #" +
75 std::to_string(dim) + " of input/output operand #" +
76 std::to_string(opOperand.getOperandNumber()));
77 builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
78
79 // Generate:
80 // inferredDimSize = max(startIndex, endIndex) + 1
81 // actualDimSize = dim(operand)
82 // assert(inferredDimSize <= actualDimSize)
83 // To ensure that we do not index past the bounds of the operands.
84 auto max =
85 builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
86
87 auto inferredDimSize =
88 builder.createOrFold<index::AddOp>(loc, max, one);
89
90 auto actualDimSize =
91 createOrFoldDimOp(builder, loc, opOperand.get(), dim);
92
93 // Similar to the verifier, when the affine expression in the indexing
94 // map is complicated, we just check that the inferred dimension sizes
95 // are in the boundary of the operands' size. Being more precise than
96 // that is difficult.
97 auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
98 ? index::IndexCmpPredicate::EQ
99 : index::IndexCmpPredicate::SLE;
100
101 cmpOp = builder.createOrFold<index::CmpOp>(
102 loc, predicate, inferredDimSize, actualDimSize);
103 msg = RuntimeVerifiableOpInterface::generateErrorMessage(
104 linalgOp, "dimension #" + std::to_string(dim) +
105 " of input/output operand #" +
106 std::to_string(opOperand.getOperandNumber()) +
107 " is incompatible with inferred dimension size");
108 builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
109 }
110 }
111 }
112 };
113
114 template <typename... OpTs>
attachInterface(MLIRContext * ctx)115 void attachInterface(MLIRContext *ctx) {
116 (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
117 }
118 } // namespace
119 } // namespace linalg
120 } // namespace mlir
121
registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry & registry)122 void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
123 DialectRegistry ®istry) {
124 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
125 attachInterface<
126 #define GET_OP_LIST
127 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
128 >(ctx);
129
130 // Load additional dialects of which ops may get created.
131 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
132 cf::ControlFlowDialect, index::IndexDialect,
133 tensor::TensorDialect>();
134 });
135 }
136