xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8108b08f2SMatthias Springer 
9108b08f2SMatthias Springer #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
10108b08f2SMatthias Springer 
11847a6f8fSRyan Holt #include "mlir/Dialect/Affine/IR/AffineOps.h"
12108b08f2SMatthias Springer #include "mlir/Dialect/Arith/IR/Arith.h"
13847a6f8fSRyan Holt #include "mlir/Dialect/Arith/Utils/Utils.h"
14108b08f2SMatthias Springer #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
15108b08f2SMatthias Springer #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16108b08f2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
17847a6f8fSRyan Holt #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18847a6f8fSRyan Holt #include "mlir/Dialect/Utils/IndexingUtils.h"
19108b08f2SMatthias Springer #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
20108b08f2SMatthias Springer 
215eee80ceSMatthias Springer using namespace mlir;
225eee80ceSMatthias Springer 
23108b08f2SMatthias Springer namespace mlir {
24108b08f2SMatthias Springer namespace memref {
25108b08f2SMatthias Springer namespace {
265eee80ceSMatthias Springer struct CastOpInterface
275eee80ceSMatthias Springer     : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
285eee80ceSMatthias Springer                                                          CastOp> {
295eee80ceSMatthias Springer   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
305eee80ceSMatthias Springer                                    Location loc) const {
315eee80ceSMatthias Springer     auto castOp = cast<CastOp>(op);
325550c821STres Popp     auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
335eee80ceSMatthias Springer 
345eee80ceSMatthias Springer     // Nothing to check if the result is an unranked memref.
355550c821STres Popp     auto resultType = dyn_cast<MemRefType>(castOp.getType());
365eee80ceSMatthias Springer     if (!resultType)
375eee80ceSMatthias Springer       return;
385eee80ceSMatthias Springer 
395550c821STres Popp     if (isa<UnrankedMemRefType>(srcType)) {
405eee80ceSMatthias Springer       // Check rank.
415eee80ceSMatthias Springer       Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
425eee80ceSMatthias Springer       Value resultRank =
435eee80ceSMatthias Springer           builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
445eee80ceSMatthias Springer       Value isSameRank = builder.create<arith::CmpIOp>(
455eee80ceSMatthias Springer           loc, arith::CmpIPredicate::eq, srcRank, resultRank);
46d94aeb50SRyan Holt       builder.create<cf::AssertOp>(
47d94aeb50SRyan Holt           loc, isSameRank,
48d94aeb50SRyan Holt           RuntimeVerifiableOpInterface::generateErrorMessage(op,
49d94aeb50SRyan Holt                                                              "rank mismatch"));
505eee80ceSMatthias Springer     }
515eee80ceSMatthias Springer 
525eee80ceSMatthias Springer     // Get source offset and strides. We do not have an op to get offsets and
535eee80ceSMatthias Springer     // strides from unranked memrefs, so cast the source to a type with fully
545eee80ceSMatthias Springer     // dynamic layout, from which we can then extract the offset and strides.
555eee80ceSMatthias Springer     // (Rank was already verified.)
565eee80ceSMatthias Springer     int64_t dynamicOffset = ShapedType::kDynamic;
575eee80ceSMatthias Springer     SmallVector<int64_t> dynamicShape(resultType.getRank(),
585eee80ceSMatthias Springer                                       ShapedType::kDynamic);
595eee80ceSMatthias Springer     auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
605eee80ceSMatthias Springer                                                 dynamicOffset, dynamicShape);
615eee80ceSMatthias Springer     auto dynStridesType =
625eee80ceSMatthias Springer         MemRefType::get(dynamicShape, resultType.getElementType(),
635eee80ceSMatthias Springer                         stridedLayout, resultType.getMemorySpace());
645eee80ceSMatthias Springer     Value helperCast =
655eee80ceSMatthias Springer         builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
665eee80ceSMatthias Springer     auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
675eee80ceSMatthias Springer 
685eee80ceSMatthias Springer     // Check dimension sizes.
695eee80ceSMatthias Springer     for (const auto &it : llvm::enumerate(resultType.getShape())) {
705eee80ceSMatthias Springer       // Static dim size -> static/dynamic dim size does not need verification.
715550c821STres Popp       if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
725eee80ceSMatthias Springer         if (!rankedSrcType.isDynamicDim(it.index()))
735eee80ceSMatthias Springer           continue;
745eee80ceSMatthias Springer 
755eee80ceSMatthias Springer       // Static/dynamic dim size -> dynamic dim size does not need verification.
765eee80ceSMatthias Springer       if (resultType.isDynamicDim(it.index()))
775eee80ceSMatthias Springer         continue;
785eee80ceSMatthias Springer 
795eee80ceSMatthias Springer       Value srcDimSz =
805eee80ceSMatthias Springer           builder.create<DimOp>(loc, castOp.getSource(), it.index());
815eee80ceSMatthias Springer       Value resultDimSz =
825eee80ceSMatthias Springer           builder.create<arith::ConstantIndexOp>(loc, it.value());
835eee80ceSMatthias Springer       Value isSameSz = builder.create<arith::CmpIOp>(
845eee80ceSMatthias Springer           loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
855eee80ceSMatthias Springer       builder.create<cf::AssertOp>(
865eee80ceSMatthias Springer           loc, isSameSz,
87d94aeb50SRyan Holt           RuntimeVerifiableOpInterface::generateErrorMessage(
88d94aeb50SRyan Holt               op, "size mismatch of dim " + std::to_string(it.index())));
895eee80ceSMatthias Springer     }
905eee80ceSMatthias Springer 
915eee80ceSMatthias Springer     // Get result offset and strides.
925eee80ceSMatthias Springer     int64_t resultOffset;
935eee80ceSMatthias Springer     SmallVector<int64_t> resultStrides;
94*6aaa8f25SMatthias Springer     if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
955eee80ceSMatthias Springer       return;
965eee80ceSMatthias Springer 
975eee80ceSMatthias Springer     // Check offset.
985eee80ceSMatthias Springer     if (resultOffset != ShapedType::kDynamic) {
995eee80ceSMatthias Springer       // Static/dynamic offset -> dynamic offset does not need verification.
1005eee80ceSMatthias Springer       Value srcOffset = metadataOp.getResult(1);
1015eee80ceSMatthias Springer       Value resultOffsetVal =
1025eee80ceSMatthias Springer           builder.create<arith::ConstantIndexOp>(loc, resultOffset);
1035eee80ceSMatthias Springer       Value isSameOffset = builder.create<arith::CmpIOp>(
1045eee80ceSMatthias Springer           loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
105d94aeb50SRyan Holt       builder.create<cf::AssertOp>(
106d94aeb50SRyan Holt           loc, isSameOffset,
107d94aeb50SRyan Holt           RuntimeVerifiableOpInterface::generateErrorMessage(
108d94aeb50SRyan Holt               op, "offset mismatch"));
1095eee80ceSMatthias Springer     }
1105eee80ceSMatthias Springer 
1115eee80ceSMatthias Springer     // Check strides.
1125eee80ceSMatthias Springer     for (const auto &it : llvm::enumerate(resultStrides)) {
1135eee80ceSMatthias Springer       // Static/dynamic stride -> dynamic stride does not need verification.
1145eee80ceSMatthias Springer       if (it.value() == ShapedType::kDynamic)
1155eee80ceSMatthias Springer         continue;
1165eee80ceSMatthias Springer 
1175eee80ceSMatthias Springer       Value srcStride =
1185eee80ceSMatthias Springer           metadataOp.getResult(2 + resultType.getRank() + it.index());
1195eee80ceSMatthias Springer       Value resultStrideVal =
1205eee80ceSMatthias Springer           builder.create<arith::ConstantIndexOp>(loc, it.value());
1215eee80ceSMatthias Springer       Value isSameStride = builder.create<arith::CmpIOp>(
1225eee80ceSMatthias Springer           loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
1235eee80ceSMatthias Springer       builder.create<cf::AssertOp>(
1245eee80ceSMatthias Springer           loc, isSameStride,
125d94aeb50SRyan Holt           RuntimeVerifiableOpInterface::generateErrorMessage(
126d94aeb50SRyan Holt               op, "stride mismatch of dim " + std::to_string(it.index())));
1275eee80ceSMatthias Springer     }
1285eee80ceSMatthias Springer   }
1295eee80ceSMatthias Springer };
1305eee80ceSMatthias Springer 
131847a6f8fSRyan Holt /// Verifies that the indices on load/store ops are in-bounds of the memref's
132847a6f8fSRyan Holt /// index space: 0 <= index#i < dim#i
133847a6f8fSRyan Holt template <typename LoadStoreOp>
134847a6f8fSRyan Holt struct LoadStoreOpInterface
135847a6f8fSRyan Holt     : public RuntimeVerifiableOpInterface::ExternalModel<
136847a6f8fSRyan Holt           LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
137847a6f8fSRyan Holt   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
138847a6f8fSRyan Holt                                    Location loc) const {
139847a6f8fSRyan Holt     auto loadStoreOp = cast<LoadStoreOp>(op);
140847a6f8fSRyan Holt 
141847a6f8fSRyan Holt     auto memref = loadStoreOp.getMemref();
142847a6f8fSRyan Holt     auto rank = memref.getType().getRank();
143847a6f8fSRyan Holt     if (rank == 0) {
144847a6f8fSRyan Holt       return;
145847a6f8fSRyan Holt     }
146847a6f8fSRyan Holt     auto indices = loadStoreOp.getIndices();
147847a6f8fSRyan Holt 
148847a6f8fSRyan Holt     auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
149847a6f8fSRyan Holt     Value assertCond;
150847a6f8fSRyan Holt     for (auto i : llvm::seq<int64_t>(0, rank)) {
151847a6f8fSRyan Holt       auto index = indices[i];
152847a6f8fSRyan Holt 
153847a6f8fSRyan Holt       auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
154847a6f8fSRyan Holt 
155847a6f8fSRyan Holt       auto geLow = builder.createOrFold<arith::CmpIOp>(
156847a6f8fSRyan Holt           loc, arith::CmpIPredicate::sge, index, zero);
157847a6f8fSRyan Holt       auto ltHigh = builder.createOrFold<arith::CmpIOp>(
158847a6f8fSRyan Holt           loc, arith::CmpIPredicate::slt, index, dimOp);
159847a6f8fSRyan Holt       auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
160847a6f8fSRyan Holt 
161847a6f8fSRyan Holt       assertCond =
162847a6f8fSRyan Holt           i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
163847a6f8fSRyan Holt                 : andOp;
164847a6f8fSRyan Holt     }
165847a6f8fSRyan Holt     builder.create<cf::AssertOp>(
166d94aeb50SRyan Holt         loc, assertCond,
167d94aeb50SRyan Holt         RuntimeVerifiableOpInterface::generateErrorMessage(
168d94aeb50SRyan Holt             op, "out-of-bounds access"));
169847a6f8fSRyan Holt   }
170847a6f8fSRyan Holt };
171847a6f8fSRyan Holt 
172847a6f8fSRyan Holt /// Compute the linear index for the provided strided layout and indices.
173847a6f8fSRyan Holt Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
174847a6f8fSRyan Holt                          ArrayRef<OpFoldResult> strides,
175847a6f8fSRyan Holt                          ArrayRef<OpFoldResult> indices) {
176847a6f8fSRyan Holt   auto [expr, values] = computeLinearIndex(offset, strides, indices);
177847a6f8fSRyan Holt   auto index =
178847a6f8fSRyan Holt       affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
179847a6f8fSRyan Holt   return getValueOrCreateConstantIndexOp(builder, loc, index);
180847a6f8fSRyan Holt }
181847a6f8fSRyan Holt 
182847a6f8fSRyan Holt /// Returns two Values representing the bounds of the provided strided layout
183847a6f8fSRyan Holt /// metadata. The bounds are returned as a half open interval -- [low, high).
184847a6f8fSRyan Holt std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
185847a6f8fSRyan Holt                                             OpFoldResult offset,
186847a6f8fSRyan Holt                                             ArrayRef<OpFoldResult> strides,
187847a6f8fSRyan Holt                                             ArrayRef<OpFoldResult> sizes) {
188847a6f8fSRyan Holt   auto zeros = SmallVector<int64_t>(sizes.size(), 0);
189847a6f8fSRyan Holt   auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
190847a6f8fSRyan Holt   auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
191847a6f8fSRyan Holt   auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
192847a6f8fSRyan Holt   return {lowerBound, upperBound};
193847a6f8fSRyan Holt }
194847a6f8fSRyan Holt 
195847a6f8fSRyan Holt /// Returns two Values representing the bounds of the memref. The bounds are
196847a6f8fSRyan Holt /// returned as a half open interval -- [low, high).
197847a6f8fSRyan Holt std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
198847a6f8fSRyan Holt                                             TypedValue<BaseMemRefType> memref) {
199847a6f8fSRyan Holt   auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
200847a6f8fSRyan Holt   auto offset = runtimeMetadata.getConstifiedMixedOffset();
201847a6f8fSRyan Holt   auto strides = runtimeMetadata.getConstifiedMixedStrides();
202847a6f8fSRyan Holt   auto sizes = runtimeMetadata.getConstifiedMixedSizes();
203847a6f8fSRyan Holt   return computeLinearBounds(builder, loc, offset, strides, sizes);
204847a6f8fSRyan Holt }
205847a6f8fSRyan Holt 
206847a6f8fSRyan Holt /// Verifies that the linear bounds of a reinterpret_cast op are within the
207847a6f8fSRyan Holt /// linear bounds of the base memref: low >= baseLow && high <= baseHigh
208847a6f8fSRyan Holt struct ReinterpretCastOpInterface
209847a6f8fSRyan Holt     : public RuntimeVerifiableOpInterface::ExternalModel<
210847a6f8fSRyan Holt           ReinterpretCastOpInterface, ReinterpretCastOp> {
211847a6f8fSRyan Holt   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
212847a6f8fSRyan Holt                                    Location loc) const {
213847a6f8fSRyan Holt     auto reinterpretCast = cast<ReinterpretCastOp>(op);
214847a6f8fSRyan Holt     auto baseMemref = reinterpretCast.getSource();
215847a6f8fSRyan Holt     auto resultMemref =
216847a6f8fSRyan Holt         cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
217847a6f8fSRyan Holt 
218847a6f8fSRyan Holt     builder.setInsertionPointAfter(op);
219847a6f8fSRyan Holt 
220847a6f8fSRyan Holt     // Compute the linear bounds of the base memref
221847a6f8fSRyan Holt     auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
222847a6f8fSRyan Holt 
223847a6f8fSRyan Holt     // Compute the linear bounds of the resulting memref
224847a6f8fSRyan Holt     auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
225847a6f8fSRyan Holt 
226847a6f8fSRyan Holt     // Check low >= baseLow
227847a6f8fSRyan Holt     auto geLow = builder.createOrFold<arith::CmpIOp>(
228847a6f8fSRyan Holt         loc, arith::CmpIPredicate::sge, low, baseLow);
229847a6f8fSRyan Holt 
230847a6f8fSRyan Holt     // Check high <= baseHigh
231847a6f8fSRyan Holt     auto leHigh = builder.createOrFold<arith::CmpIOp>(
232847a6f8fSRyan Holt         loc, arith::CmpIPredicate::sle, high, baseHigh);
233847a6f8fSRyan Holt 
234847a6f8fSRyan Holt     auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
235847a6f8fSRyan Holt 
236847a6f8fSRyan Holt     builder.create<cf::AssertOp>(
237847a6f8fSRyan Holt         loc, assertCond,
238d94aeb50SRyan Holt         RuntimeVerifiableOpInterface::generateErrorMessage(
239847a6f8fSRyan Holt             op,
240847a6f8fSRyan Holt             "result of reinterpret_cast is out-of-bounds of the base memref"));
241847a6f8fSRyan Holt   }
242847a6f8fSRyan Holt };
243847a6f8fSRyan Holt 
244847a6f8fSRyan Holt /// Verifies that the linear bounds of a subview op are within the linear bounds
245847a6f8fSRyan Holt /// of the base memref: low >= baseLow && high <= baseHigh
246847a6f8fSRyan Holt /// TODO: This is not yet a full runtime verification of subview. For example,
247847a6f8fSRyan Holt /// consider:
248847a6f8fSRyan Holt ///   %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
249847a6f8fSRyan Holt ///   memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
250847a6f8fSRyan Holt ///      : memref<?x?xf32> to memref<?x?xf32>
251847a6f8fSRyan Holt /// The subview is in-bounds of the entire base memref but the first dimension
252847a6f8fSRyan Holt /// is out-of-bounds. Future work would verify the bounds on a per-dimension
253847a6f8fSRyan Holt /// basis.
254847a6f8fSRyan Holt struct SubViewOpInterface
255847a6f8fSRyan Holt     : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
256847a6f8fSRyan Holt                                                          SubViewOp> {
257847a6f8fSRyan Holt   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
258847a6f8fSRyan Holt                                    Location loc) const {
259847a6f8fSRyan Holt     auto subView = cast<SubViewOp>(op);
260847a6f8fSRyan Holt     auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
261847a6f8fSRyan Holt     auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
262847a6f8fSRyan Holt 
263847a6f8fSRyan Holt     builder.setInsertionPointAfter(op);
264847a6f8fSRyan Holt 
265847a6f8fSRyan Holt     // Compute the linear bounds of the base memref
266847a6f8fSRyan Holt     auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
267847a6f8fSRyan Holt 
268847a6f8fSRyan Holt     // Compute the linear bounds of the resulting memref
269847a6f8fSRyan Holt     auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
270847a6f8fSRyan Holt 
271847a6f8fSRyan Holt     // Check low >= baseLow
272847a6f8fSRyan Holt     auto geLow = builder.createOrFold<arith::CmpIOp>(
273847a6f8fSRyan Holt         loc, arith::CmpIPredicate::sge, low, baseLow);
274847a6f8fSRyan Holt 
275847a6f8fSRyan Holt     // Check high <= baseHigh
276847a6f8fSRyan Holt     auto leHigh = builder.createOrFold<arith::CmpIOp>(
277847a6f8fSRyan Holt         loc, arith::CmpIPredicate::sle, high, baseHigh);
278847a6f8fSRyan Holt 
279847a6f8fSRyan Holt     auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
280847a6f8fSRyan Holt 
281847a6f8fSRyan Holt     builder.create<cf::AssertOp>(
282847a6f8fSRyan Holt         loc, assertCond,
283d94aeb50SRyan Holt         RuntimeVerifiableOpInterface::generateErrorMessage(
284d94aeb50SRyan Holt             op, "subview is out-of-bounds of the base memref"));
285847a6f8fSRyan Holt   }
286847a6f8fSRyan Holt };
287847a6f8fSRyan Holt 
288108b08f2SMatthias Springer struct ExpandShapeOpInterface
289108b08f2SMatthias Springer     : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
290108b08f2SMatthias Springer                                                          ExpandShapeOp> {
291108b08f2SMatthias Springer   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
292108b08f2SMatthias Springer                                    Location loc) const {
293108b08f2SMatthias Springer     auto expandShapeOp = cast<ExpandShapeOp>(op);
294108b08f2SMatthias Springer 
295108b08f2SMatthias Springer     // Verify that the expanded dim sizes are a product of the collapsed dim
296108b08f2SMatthias Springer     // size.
29770423eedSAdrian Kuegel     for (const auto &it :
29870423eedSAdrian Kuegel          llvm::enumerate(expandShapeOp.getReassociationIndices())) {
299108b08f2SMatthias Springer       Value srcDimSz =
300108b08f2SMatthias Springer           builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
301108b08f2SMatthias Springer       int64_t groupSz = 1;
302108b08f2SMatthias Springer       bool foundDynamicDim = false;
303108b08f2SMatthias Springer       for (int64_t resultDim : it.value()) {
304108b08f2SMatthias Springer         if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
305108b08f2SMatthias Springer           // Keep this assert here in case the op is extended in the future.
306108b08f2SMatthias Springer           assert(!foundDynamicDim &&
307108b08f2SMatthias Springer                  "more than one dynamic dim found in reassoc group");
30802f4cfa3SKazu Hirata           (void)foundDynamicDim;
309108b08f2SMatthias Springer           foundDynamicDim = true;
310108b08f2SMatthias Springer           continue;
311108b08f2SMatthias Springer         }
312108b08f2SMatthias Springer         groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
313108b08f2SMatthias Springer       }
314108b08f2SMatthias Springer       Value staticResultDimSz =
315108b08f2SMatthias Springer           builder.create<arith::ConstantIndexOp>(loc, groupSz);
316108b08f2SMatthias Springer       // staticResultDimSz must divide srcDimSz evenly.
317108b08f2SMatthias Springer       Value mod =
318108b08f2SMatthias Springer           builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
319108b08f2SMatthias Springer       Value isModZero = builder.create<arith::CmpIOp>(
320108b08f2SMatthias Springer           loc, arith::CmpIPredicate::eq, mod,
321108b08f2SMatthias Springer           builder.create<arith::ConstantIndexOp>(loc, 0));
322108b08f2SMatthias Springer       builder.create<cf::AssertOp>(
323108b08f2SMatthias Springer           loc, isModZero,
324d94aeb50SRyan Holt           RuntimeVerifiableOpInterface::generateErrorMessage(
325d94aeb50SRyan Holt               op, "static result dims in reassoc group do not "
3265eee80ceSMatthias Springer                   "divide src dim evenly"));
327108b08f2SMatthias Springer     }
328108b08f2SMatthias Springer   }
329108b08f2SMatthias Springer };
330108b08f2SMatthias Springer } // namespace
331108b08f2SMatthias Springer } // namespace memref
332108b08f2SMatthias Springer } // namespace mlir
333108b08f2SMatthias Springer 
334108b08f2SMatthias Springer void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
335108b08f2SMatthias Springer     DialectRegistry &registry) {
336108b08f2SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
3375eee80ceSMatthias Springer     CastOp::attachInterface<CastOpInterface>(*ctx);
338108b08f2SMatthias Springer     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
339847a6f8fSRyan Holt     LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
340847a6f8fSRyan Holt     ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
341847a6f8fSRyan Holt     StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
342847a6f8fSRyan Holt     SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
343108b08f2SMatthias Springer 
344108b08f2SMatthias Springer     // Load additional dialects of which ops may get created.
345847a6f8fSRyan Holt     ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
346847a6f8fSRyan Holt                      cf::ControlFlowDialect>();
347108b08f2SMatthias Springer   });
348108b08f2SMatthias Springer }
349