//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

using namespace mlir;

namespace mlir {
namespace memref {
namespace {
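/// Verifies that a cast op does not refine the source into an incompatible
/// result type: the rank (for unranked sources) and any dimension size,
/// offset, or stride that is static in the result type must match the
/// source's runtime value.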
struct CastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                         CastOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto castOp = cast<CastOp>(op);
    auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());

    // Nothing to check if the result is an unranked memref.
    auto resultType = dyn_cast<MemRefType>(castOp.getType());
    if (!resultType)
      return;

    if (isa<UnrankedMemRefType>(srcType)) {
      // Check rank.
      Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
      Value resultRank =
          builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
      Value isSameRank = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcRank, resultRank);
      builder.create<cf::AssertOp>(
          loc, isSameRank,
          RuntimeVerifiableOpInterface::generateErrorMessage(op,
                                                             "rank mismatch"));
    }

    // Get source offset and strides. We do not have an op to get offsets and
    // strides from unranked memrefs, so cast the source to a type with fully
    // dynamic layout, from which we can then extract the offset and strides.
    // (Rank was already verified.)
    int64_t dynamicOffset = ShapedType::kDynamic;
    SmallVector<int64_t> dynamicShape(resultType.getRank(),
                                      ShapedType::kDynamic);
    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
                                                dynamicOffset, dynamicShape);
    auto dynStridesType =
        MemRefType::get(dynamicShape, resultType.getElementType(),
                        stridedLayout, resultType.getMemorySpace());
    Value helperCast =
        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
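    // Results of extract_strided_metadata are ordered as
    // [base buffer, offset, sizes..., strides...].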
    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);

    // Check dimension sizes.
    for (const auto &it : llvm::enumerate(resultType.getShape())) {
      // Static dim size -> static/dynamic dim size does not need verification.
      if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
        if (!rankedSrcType.isDynamicDim(it.index()))
          continue;

      // Static/dynamic dim size -> dynamic dim size does not need verification.
      if (resultType.isDynamicDim(it.index()))
        continue;

      Value srcDimSz =
          builder.create<DimOp>(loc, castOp.getSource(), it.index());
      Value resultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameSz = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
      builder.create<cf::AssertOp>(
          loc, isSameSz,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "size mismatch of dim " + std::to_string(it.index())));
    }

    // Get result offset and strides.
    int64_t resultOffset;
    SmallVector<int64_t> resultStrides;
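    // Bail out if the result layout cannot be expressed as strides and an
    // offset.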
    if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
      return;

    // Check offset.
    if (resultOffset != ShapedType::kDynamic) {
      // Static/dynamic offset -> dynamic offset does not need verification.
      Value srcOffset = metadataOp.getResult(1);
      Value resultOffsetVal =
          builder.create<arith::ConstantIndexOp>(loc, resultOffset);
      Value isSameOffset = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
      builder.create<cf::AssertOp>(
          loc, isSameOffset,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "offset mismatch"));
    }

    // Check strides.
    for (const auto &it : llvm::enumerate(resultStrides)) {
      // Static/dynamic stride -> dynamic stride does not need verification.
      if (it.value() == ShapedType::kDynamic)
        continue;

      Value srcStride =
          metadataOp.getResult(2 + resultType.getRank() + it.index());
      Value resultStrideVal =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameStride = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
      builder.create<cf::AssertOp>(
          loc, isSameStride,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "stride mismatch of dim " + std::to_string(it.index())));
    }
  }
};

/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
struct LoadStoreOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto loadStoreOp = cast<LoadStoreOp>(op);

    auto memref = loadStoreOp.getMemref();
    auto rank = memref.getType().getRank();
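    // A rank-0 memref takes no indices, so there is nothing to verify.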
    if (rank == 0) {
      return;
    }
    auto indices = loadStoreOp.getIndices();

    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value assertCond;
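    // AND together the per-dimension checks 0 <= index#i < dim#i into a single
    // assert condition.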
    for (auto i : llvm::seq<int64_t>(0, rank)) {
      auto index = indices[i];

      auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);

      auto geLow = builder.createOrFold<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, index, zero);
      auto ltHigh = builder.createOrFold<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, index, dimOp);
      auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);

      assertCond =
          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
                : andOp;
    }
    builder.create<cf::AssertOp>(
        loc, assertCond,
        RuntimeVerifiableOpInterface::generateErrorMessage(
            op, "out-of-bounds access"));
  }
};

/// Compute the linear index for the provided strided layout and indices.
Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  auto [expr, values] = computeLinearIndex(offset, strides, indices);
  auto index =
      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
  return getValueOrCreateConstantIndexOp(builder, loc, index);
}

/// Returns two Values representing the bounds of the provided strided layout
/// metadata. The bounds are returned as a half-open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
                                            OpFoldResult offset,
                                            ArrayRef<OpFoldResult> strides,
                                            ArrayRef<OpFoldResult> sizes) {
  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
  return {lowerBound, upperBound};
}

/// Returns two Values representing the bounds of the memref. The bounds are
/// returned as a half-open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
                                            TypedValue<BaseMemRefType> memref) {
  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
  auto offset = runtimeMetadata.getConstifiedMixedOffset();
  auto strides = runtimeMetadata.getConstifiedMixedStrides();
  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
  return computeLinearBounds(builder, loc, offset, strides, sizes);
}

/// Verifies that the linear bounds of a reinterpret_cast op are within the
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
struct ReinterpretCastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          ReinterpretCastOpInterface, ReinterpretCastOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto reinterpretCast = cast<ReinterpretCastOp>(op);
    auto baseMemref = reinterpretCast.getSource();
    auto resultMemref =
        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());

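    // The verification ops need the result memref value, so they are inserted
    // after the op.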
    builder.setInsertionPointAfter(op);

    // Compute the linear bounds of the base memref
    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

    // Compute the linear bounds of the resulting memref
    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

    // Check low >= baseLow
    auto geLow = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, low, baseLow);

    // Check high <= baseHigh
    auto leHigh = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, high, baseHigh);

    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

    builder.create<cf::AssertOp>(
        loc, assertCond,
        RuntimeVerifiableOpInterface::generateErrorMessage(
            op,
            "result of reinterpret_cast is out-of-bounds of the base memref"));
  }
};

243 
244 /// Verifies that the linear bounds of a subview op are within the linear bounds
245 /// of the base memref: low >= baseLow && high <= baseHigh
246 /// TODO: This is not yet a full runtime verification of subview. For example,
247 /// consider:
248 ///   %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
249 ///   memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
250 ///      : memref<?x?xf32> to memref<?x?xf32>
251 /// The subview is in-bounds of the entire base memref but the first dimension
252 /// is out-of-bounds. Future work would verify the bounds on a per-dimension
253 /// basis.
struct SubViewOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
                                                         SubViewOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto subView = cast<SubViewOp>(op);
    auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
    auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());

    builder.setInsertionPointAfter(op);

    // Compute the linear bounds of the base memref
    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

    // Compute the linear bounds of the resulting memref
    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

    // Check low >= baseLow
    auto geLow = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, low, baseLow);

    // Check high <= baseHigh
    auto leHigh = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, high, baseHigh);

    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

    builder.create<cf::AssertOp>(
        loc, assertCond,
        RuntimeVerifiableOpInterface::generateErrorMessage(
            op, "subview is out-of-bounds of the base memref"));
  }
};

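/// Verifies that the collapsed (source) dimension size of each reassociation
/// group of an expand_shape op is evenly divisible by the product of the
/// static expanded (result) dimension sizes in that group.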
struct ExpandShapeOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                         ExpandShapeOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto expandShapeOp = cast<ExpandShapeOp>(op);

    // Verify that, for each reassociation group, the product of the static
    // expanded (result) dim sizes divides the collapsed (source) dim size
    // evenly.
    for (const auto &it :
         llvm::enumerate(expandShapeOp.getReassociationIndices())) {
      Value srcDimSz =
          builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
      int64_t groupSz = 1;
      bool foundDynamicDim = false;
      for (int64_t resultDim : it.value()) {
        if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
          // Keep this assert here in case the op is extended in the future.
          assert(!foundDynamicDim &&
                 "more than one dynamic dim found in reassoc group");
          (void)foundDynamicDim;
          foundDynamicDim = true;
          continue;
        }
        groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
      }
      Value staticResultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, groupSz);
      // staticResultDimSz must divide srcDimSz evenly.
      Value mod =
          builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
      Value isModZero = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, mod,
          builder.create<arith::ConstantIndexOp>(loc, 0));
      builder.create<cf::AssertOp>(
          loc, isModZero,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "static result dims in reassoc group do not "
                  "divide src dim evenly"));
    }
  }
};
} // namespace
} // namespace memref
} // namespace mlir

void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
    ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
    StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
    SubViewOp::attachInterface<SubViewOpInterface>(*ctx);

    // Load additional dialects whose ops may get created.
    ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
                     cf::ControlFlowDialect>();
  });
}