xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp (revision d94aeb507d71d72f4153b4c87c77fcb5187b3e9a)
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
16 #include "mlir/Dialect/Index/IR/IndexDialect.h"
17 #include "mlir/Dialect/Index/IR/IndexOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
22 
23 namespace mlir {
24 namespace linalg {
25 namespace {
26 /// Verify that the runtime sizes of the operands to linalg structured ops are
27 /// compatible with the runtime sizes inferred by composing the loop ranges with
28 /// the linalg op's indexing maps. This is similar to the verifier except that
29 /// here we insert IR to perform the verification at runtime.
30 template <typename T>
31 struct StructuredOpInterface
32     : public RuntimeVerifiableOpInterface::ExternalModel<
33           StructuredOpInterface<T>, T> {
generateRuntimeVerificationmlir::linalg::__anon31d6dceb0111::StructuredOpInterface34   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
35                                    Location loc) const {
36     auto linalgOp = llvm::cast<LinalgOp>(op);
37 
38     SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
39     auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
40 
41     auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
42     auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
43 
44     // Subtract one from the loop ends before composing with the indexing map
45     transform(ends, ends.begin(), [&](OpFoldResult end) {
46       auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
47       return builder.createOrFold<index::SubOp>(loc, endValue, one);
48     });
49 
50     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
51       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
52       auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
53           builder, loc, indexingMap, starts);
54       auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
55           builder, loc, indexingMap, ends);
56 
57       for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
58         auto startIndex =
59             getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
60         auto endIndex =
61             getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
62 
63         // Generate:
64         //   minIndex = min(startIndex, endIndex)
65         //   assert(minIndex >= 0)
66         // To ensure we do not generate a negative index. We take the minimum of
67         // the start and end indices in order to handle reverse loops such as
68         // `affine_map<(i) -> (3 - i)>`
69         auto min =
70             builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
71         auto cmpOp = builder.createOrFold<index::CmpOp>(
72             loc, index::IndexCmpPredicate::SGE, min, zero);
73         auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
74             linalgOp, "unexpected negative result on dimension #" +
75                           std::to_string(dim) + " of input/output operand #" +
76                           std::to_string(opOperand.getOperandNumber()));
77         builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
78 
79         // Generate:
80         //   inferredDimSize = max(startIndex, endIndex) + 1
81         //   actualDimSize = dim(operand)
82         //   assert(inferredDimSize <= actualDimSize)
83         // To ensure that we do not index past the bounds of the operands.
84         auto max =
85             builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
86 
87         auto inferredDimSize =
88             builder.createOrFold<index::AddOp>(loc, max, one);
89 
90         auto actualDimSize =
91             createOrFoldDimOp(builder, loc, opOperand.get(), dim);
92 
93         // Similar to the verifier, when the affine expression in the indexing
94         // map is complicated, we just check that the inferred dimension sizes
95         // are in the boundary of the operands' size. Being more precise than
96         // that is difficult.
97         auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
98                              ? index::IndexCmpPredicate::EQ
99                              : index::IndexCmpPredicate::SLE;
100 
101         cmpOp = builder.createOrFold<index::CmpOp>(
102             loc, predicate, inferredDimSize, actualDimSize);
103         msg = RuntimeVerifiableOpInterface::generateErrorMessage(
104             linalgOp, "dimension #" + std::to_string(dim) +
105                           " of input/output operand #" +
106                           std::to_string(opOperand.getOperandNumber()) +
107                           " is incompatible with inferred dimension size");
108         builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
109       }
110     }
111   }
112 };
113 
114 template <typename... OpTs>
attachInterface(MLIRContext * ctx)115 void attachInterface(MLIRContext *ctx) {
116   (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
117 }
118 } // namespace
119 } // namespace linalg
120 } // namespace mlir
121 
registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry & registry)122 void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
123     DialectRegistry &registry) {
124   registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
125     attachInterface<
126 #define GET_OP_LIST
127 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
128         >(ctx);
129 
130     // Load additional dialects of which ops may get created.
131     ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
132                      cf::ControlFlowDialect, index::IndexDialect,
133                      tensor::TensorDialect>();
134   });
135 }
136