1108b08f2SMatthias Springer //===- RuntimeOpVerification.cpp - Op Verification ------------------------===// 2108b08f2SMatthias Springer // 3108b08f2SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4108b08f2SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5108b08f2SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6108b08f2SMatthias Springer // 7108b08f2SMatthias Springer //===----------------------------------------------------------------------===// 8108b08f2SMatthias Springer 9108b08f2SMatthias Springer #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" 10108b08f2SMatthias Springer 11847a6f8fSRyan Holt #include "mlir/Dialect/Affine/IR/AffineOps.h" 12108b08f2SMatthias Springer #include "mlir/Dialect/Arith/IR/Arith.h" 13847a6f8fSRyan Holt #include "mlir/Dialect/Arith/Utils/Utils.h" 14108b08f2SMatthias Springer #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 15108b08f2SMatthias Springer #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 16108b08f2SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 17847a6f8fSRyan Holt #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 18847a6f8fSRyan Holt #include "mlir/Dialect/Utils/IndexingUtils.h" 19108b08f2SMatthias Springer #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" 20108b08f2SMatthias Springer 215eee80ceSMatthias Springer using namespace mlir; 225eee80ceSMatthias Springer 23108b08f2SMatthias Springer namespace mlir { 24108b08f2SMatthias Springer namespace memref { 25108b08f2SMatthias Springer namespace { 265eee80ceSMatthias Springer struct CastOpInterface 275eee80ceSMatthias Springer : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, 285eee80ceSMatthias Springer CastOp> { 295eee80ceSMatthias Springer void generateRuntimeVerification(Operation *op, OpBuilder &builder, 305eee80ceSMatthias Springer Location loc) const { 315eee80ceSMatthias Springer auto castOp = cast<CastOp>(op); 325550c821STres Popp auto srcType = cast<BaseMemRefType>(castOp.getSource().getType()); 335eee80ceSMatthias Springer 345eee80ceSMatthias Springer // Nothing to check if the result is an unranked memref. 355550c821STres Popp auto resultType = dyn_cast<MemRefType>(castOp.getType()); 365eee80ceSMatthias Springer if (!resultType) 375eee80ceSMatthias Springer return; 385eee80ceSMatthias Springer 395550c821STres Popp if (isa<UnrankedMemRefType>(srcType)) { 405eee80ceSMatthias Springer // Check rank. 415eee80ceSMatthias Springer Value srcRank = builder.create<RankOp>(loc, castOp.getSource()); 425eee80ceSMatthias Springer Value resultRank = 435eee80ceSMatthias Springer builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); 445eee80ceSMatthias Springer Value isSameRank = builder.create<arith::CmpIOp>( 455eee80ceSMatthias Springer loc, arith::CmpIPredicate::eq, srcRank, resultRank); 46d94aeb50SRyan Holt builder.create<cf::AssertOp>( 47d94aeb50SRyan Holt loc, isSameRank, 48d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage(op, 49d94aeb50SRyan Holt "rank mismatch")); 505eee80ceSMatthias Springer } 515eee80ceSMatthias Springer 525eee80ceSMatthias Springer // Get source offset and strides. We do not have an op to get offsets and 535eee80ceSMatthias Springer // strides from unranked memrefs, so cast the source to a type with fully 545eee80ceSMatthias Springer // dynamic layout, from which we can then extract the offset and strides. 555eee80ceSMatthias Springer // (Rank was already verified.) 565eee80ceSMatthias Springer int64_t dynamicOffset = ShapedType::kDynamic; 575eee80ceSMatthias Springer SmallVector<int64_t> dynamicShape(resultType.getRank(), 585eee80ceSMatthias Springer ShapedType::kDynamic); 595eee80ceSMatthias Springer auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), 605eee80ceSMatthias Springer dynamicOffset, dynamicShape); 615eee80ceSMatthias Springer auto dynStridesType = 625eee80ceSMatthias Springer MemRefType::get(dynamicShape, resultType.getElementType(), 635eee80ceSMatthias Springer stridedLayout, resultType.getMemorySpace()); 645eee80ceSMatthias Springer Value helperCast = 655eee80ceSMatthias Springer builder.create<CastOp>(loc, dynStridesType, castOp.getSource()); 665eee80ceSMatthias Springer auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast); 675eee80ceSMatthias Springer 685eee80ceSMatthias Springer // Check dimension sizes. 695eee80ceSMatthias Springer for (const auto &it : llvm::enumerate(resultType.getShape())) { 705eee80ceSMatthias Springer // Static dim size -> static/dynamic dim size does not need verification. 715550c821STres Popp if (auto rankedSrcType = dyn_cast<MemRefType>(srcType)) 725eee80ceSMatthias Springer if (!rankedSrcType.isDynamicDim(it.index())) 735eee80ceSMatthias Springer continue; 745eee80ceSMatthias Springer 755eee80ceSMatthias Springer // Static/dynamic dim size -> dynamic dim size does not need verification. 765eee80ceSMatthias Springer if (resultType.isDynamicDim(it.index())) 775eee80ceSMatthias Springer continue; 785eee80ceSMatthias Springer 795eee80ceSMatthias Springer Value srcDimSz = 805eee80ceSMatthias Springer builder.create<DimOp>(loc, castOp.getSource(), it.index()); 815eee80ceSMatthias Springer Value resultDimSz = 825eee80ceSMatthias Springer builder.create<arith::ConstantIndexOp>(loc, it.value()); 835eee80ceSMatthias Springer Value isSameSz = builder.create<arith::CmpIOp>( 845eee80ceSMatthias Springer loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); 855eee80ceSMatthias Springer builder.create<cf::AssertOp>( 865eee80ceSMatthias Springer loc, isSameSz, 87d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 88d94aeb50SRyan Holt op, "size mismatch of dim " + std::to_string(it.index()))); 895eee80ceSMatthias Springer } 905eee80ceSMatthias Springer 915eee80ceSMatthias Springer // Get result offset and strides. 925eee80ceSMatthias Springer int64_t resultOffset; 935eee80ceSMatthias Springer SmallVector<int64_t> resultStrides; 94*6aaa8f25SMatthias Springer if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) 955eee80ceSMatthias Springer return; 965eee80ceSMatthias Springer 975eee80ceSMatthias Springer // Check offset. 985eee80ceSMatthias Springer if (resultOffset != ShapedType::kDynamic) { 995eee80ceSMatthias Springer // Static/dynamic offset -> dynamic offset does not need verification. 1005eee80ceSMatthias Springer Value srcOffset = metadataOp.getResult(1); 1015eee80ceSMatthias Springer Value resultOffsetVal = 1025eee80ceSMatthias Springer builder.create<arith::ConstantIndexOp>(loc, resultOffset); 1035eee80ceSMatthias Springer Value isSameOffset = builder.create<arith::CmpIOp>( 1045eee80ceSMatthias Springer loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); 105d94aeb50SRyan Holt builder.create<cf::AssertOp>( 106d94aeb50SRyan Holt loc, isSameOffset, 107d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 108d94aeb50SRyan Holt op, "offset mismatch")); 1095eee80ceSMatthias Springer } 1105eee80ceSMatthias Springer 1115eee80ceSMatthias Springer // Check strides. 1125eee80ceSMatthias Springer for (const auto &it : llvm::enumerate(resultStrides)) { 1135eee80ceSMatthias Springer // Static/dynamic stride -> dynamic stride does not need verification. 1145eee80ceSMatthias Springer if (it.value() == ShapedType::kDynamic) 1155eee80ceSMatthias Springer continue; 1165eee80ceSMatthias Springer 1175eee80ceSMatthias Springer Value srcStride = 1185eee80ceSMatthias Springer metadataOp.getResult(2 + resultType.getRank() + it.index()); 1195eee80ceSMatthias Springer Value resultStrideVal = 1205eee80ceSMatthias Springer builder.create<arith::ConstantIndexOp>(loc, it.value()); 1215eee80ceSMatthias Springer Value isSameStride = builder.create<arith::CmpIOp>( 1225eee80ceSMatthias Springer loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); 1235eee80ceSMatthias Springer builder.create<cf::AssertOp>( 1245eee80ceSMatthias Springer loc, isSameStride, 125d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 126d94aeb50SRyan Holt op, "stride mismatch of dim " + std::to_string(it.index()))); 1275eee80ceSMatthias Springer } 1285eee80ceSMatthias Springer } 1295eee80ceSMatthias Springer }; 1305eee80ceSMatthias Springer 131847a6f8fSRyan Holt /// Verifies that the indices on load/store ops are in-bounds of the memref's 132847a6f8fSRyan Holt /// index space: 0 <= index#i < dim#i 133847a6f8fSRyan Holt template <typename LoadStoreOp> 134847a6f8fSRyan Holt struct LoadStoreOpInterface 135847a6f8fSRyan Holt : public RuntimeVerifiableOpInterface::ExternalModel< 136847a6f8fSRyan Holt LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> { 137847a6f8fSRyan Holt void generateRuntimeVerification(Operation *op, OpBuilder &builder, 138847a6f8fSRyan Holt Location loc) const { 139847a6f8fSRyan Holt auto loadStoreOp = cast<LoadStoreOp>(op); 140847a6f8fSRyan Holt 141847a6f8fSRyan Holt auto memref = loadStoreOp.getMemref(); 142847a6f8fSRyan Holt auto rank = memref.getType().getRank(); 143847a6f8fSRyan Holt if (rank == 0) { 144847a6f8fSRyan Holt return; 145847a6f8fSRyan Holt } 146847a6f8fSRyan Holt auto indices = loadStoreOp.getIndices(); 147847a6f8fSRyan Holt 148847a6f8fSRyan Holt auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); 149847a6f8fSRyan Holt Value assertCond; 150847a6f8fSRyan Holt for (auto i : llvm::seq<int64_t>(0, rank)) { 151847a6f8fSRyan Holt auto index = indices[i]; 152847a6f8fSRyan Holt 153847a6f8fSRyan Holt auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i); 154847a6f8fSRyan Holt 155847a6f8fSRyan Holt auto geLow = builder.createOrFold<arith::CmpIOp>( 156847a6f8fSRyan Holt loc, arith::CmpIPredicate::sge, index, zero); 157847a6f8fSRyan Holt auto ltHigh = builder.createOrFold<arith::CmpIOp>( 158847a6f8fSRyan Holt loc, arith::CmpIPredicate::slt, index, dimOp); 159847a6f8fSRyan Holt auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh); 160847a6f8fSRyan Holt 161847a6f8fSRyan Holt assertCond = 162847a6f8fSRyan Holt i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp) 163847a6f8fSRyan Holt : andOp; 164847a6f8fSRyan Holt } 165847a6f8fSRyan Holt builder.create<cf::AssertOp>( 166d94aeb50SRyan Holt loc, assertCond, 167d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 168d94aeb50SRyan Holt op, "out-of-bounds access")); 169847a6f8fSRyan Holt } 170847a6f8fSRyan Holt }; 171847a6f8fSRyan Holt 172847a6f8fSRyan Holt /// Compute the linear index for the provided strided layout and indices. 173847a6f8fSRyan Holt Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset, 174847a6f8fSRyan Holt ArrayRef<OpFoldResult> strides, 175847a6f8fSRyan Holt ArrayRef<OpFoldResult> indices) { 176847a6f8fSRyan Holt auto [expr, values] = computeLinearIndex(offset, strides, indices); 177847a6f8fSRyan Holt auto index = 178847a6f8fSRyan Holt affine::makeComposedFoldedAffineApply(builder, loc, expr, values); 179847a6f8fSRyan Holt return getValueOrCreateConstantIndexOp(builder, loc, index); 180847a6f8fSRyan Holt } 181847a6f8fSRyan Holt 182847a6f8fSRyan Holt /// Returns two Values representing the bounds of the provided strided layout 183847a6f8fSRyan Holt /// metadata. The bounds are returned as a half open interval -- [low, high). 184847a6f8fSRyan Holt std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, 185847a6f8fSRyan Holt OpFoldResult offset, 186847a6f8fSRyan Holt ArrayRef<OpFoldResult> strides, 187847a6f8fSRyan Holt ArrayRef<OpFoldResult> sizes) { 188847a6f8fSRyan Holt auto zeros = SmallVector<int64_t>(sizes.size(), 0); 189847a6f8fSRyan Holt auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros); 190847a6f8fSRyan Holt auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices); 191847a6f8fSRyan Holt auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes); 192847a6f8fSRyan Holt return {lowerBound, upperBound}; 193847a6f8fSRyan Holt } 194847a6f8fSRyan Holt 195847a6f8fSRyan Holt /// Returns two Values representing the bounds of the memref. The bounds are 196847a6f8fSRyan Holt /// returned as a half open interval -- [low, high). 197847a6f8fSRyan Holt std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, 198847a6f8fSRyan Holt TypedValue<BaseMemRefType> memref) { 199847a6f8fSRyan Holt auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref); 200847a6f8fSRyan Holt auto offset = runtimeMetadata.getConstifiedMixedOffset(); 201847a6f8fSRyan Holt auto strides = runtimeMetadata.getConstifiedMixedStrides(); 202847a6f8fSRyan Holt auto sizes = runtimeMetadata.getConstifiedMixedSizes(); 203847a6f8fSRyan Holt return computeLinearBounds(builder, loc, offset, strides, sizes); 204847a6f8fSRyan Holt } 205847a6f8fSRyan Holt 206847a6f8fSRyan Holt /// Verifies that the linear bounds of a reinterpret_cast op are within the 207847a6f8fSRyan Holt /// linear bounds of the base memref: low >= baseLow && high <= baseHigh 208847a6f8fSRyan Holt struct ReinterpretCastOpInterface 209847a6f8fSRyan Holt : public RuntimeVerifiableOpInterface::ExternalModel< 210847a6f8fSRyan Holt ReinterpretCastOpInterface, ReinterpretCastOp> { 211847a6f8fSRyan Holt void generateRuntimeVerification(Operation *op, OpBuilder &builder, 212847a6f8fSRyan Holt Location loc) const { 213847a6f8fSRyan Holt auto reinterpretCast = cast<ReinterpretCastOp>(op); 214847a6f8fSRyan Holt auto baseMemref = reinterpretCast.getSource(); 215847a6f8fSRyan Holt auto resultMemref = 216847a6f8fSRyan Holt cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult()); 217847a6f8fSRyan Holt 218847a6f8fSRyan Holt builder.setInsertionPointAfter(op); 219847a6f8fSRyan Holt 220847a6f8fSRyan Holt // Compute the linear bounds of the base memref 221847a6f8fSRyan Holt auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); 222847a6f8fSRyan Holt 223847a6f8fSRyan Holt // Compute the linear bounds of the resulting memref 224847a6f8fSRyan Holt auto [low, high] = computeLinearBounds(builder, loc, resultMemref); 225847a6f8fSRyan Holt 226847a6f8fSRyan Holt // Check low >= baseLow 227847a6f8fSRyan Holt auto geLow = builder.createOrFold<arith::CmpIOp>( 228847a6f8fSRyan Holt loc, arith::CmpIPredicate::sge, low, baseLow); 229847a6f8fSRyan Holt 230847a6f8fSRyan Holt // Check high <= baseHigh 231847a6f8fSRyan Holt auto leHigh = builder.createOrFold<arith::CmpIOp>( 232847a6f8fSRyan Holt loc, arith::CmpIPredicate::sle, high, baseHigh); 233847a6f8fSRyan Holt 234847a6f8fSRyan Holt auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); 235847a6f8fSRyan Holt 236847a6f8fSRyan Holt builder.create<cf::AssertOp>( 237847a6f8fSRyan Holt loc, assertCond, 238d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 239847a6f8fSRyan Holt op, 240847a6f8fSRyan Holt "result of reinterpret_cast is out-of-bounds of the base memref")); 241847a6f8fSRyan Holt } 242847a6f8fSRyan Holt }; 243847a6f8fSRyan Holt 244847a6f8fSRyan Holt /// Verifies that the linear bounds of a subview op are within the linear bounds 245847a6f8fSRyan Holt /// of the base memref: low >= baseLow && high <= baseHigh 246847a6f8fSRyan Holt /// TODO: This is not yet a full runtime verification of subview. For example, 247847a6f8fSRyan Holt /// consider: 248847a6f8fSRyan Holt /// %m = memref.alloc(%c10, %c10) : memref<10x10xf32> 249847a6f8fSRyan Holt /// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1] 250847a6f8fSRyan Holt /// : memref<?x?xf32> to memref<?x?xf32> 251847a6f8fSRyan Holt /// The subview is in-bounds of the entire base memref but the first dimension 252847a6f8fSRyan Holt /// is out-of-bounds. Future work would verify the bounds on a per-dimension 253847a6f8fSRyan Holt /// basis. 254847a6f8fSRyan Holt struct SubViewOpInterface 255847a6f8fSRyan Holt : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface, 256847a6f8fSRyan Holt SubViewOp> { 257847a6f8fSRyan Holt void generateRuntimeVerification(Operation *op, OpBuilder &builder, 258847a6f8fSRyan Holt Location loc) const { 259847a6f8fSRyan Holt auto subView = cast<SubViewOp>(op); 260847a6f8fSRyan Holt auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource()); 261847a6f8fSRyan Holt auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult()); 262847a6f8fSRyan Holt 263847a6f8fSRyan Holt builder.setInsertionPointAfter(op); 264847a6f8fSRyan Holt 265847a6f8fSRyan Holt // Compute the linear bounds of the base memref 266847a6f8fSRyan Holt auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); 267847a6f8fSRyan Holt 268847a6f8fSRyan Holt // Compute the linear bounds of the resulting memref 269847a6f8fSRyan Holt auto [low, high] = computeLinearBounds(builder, loc, resultMemref); 270847a6f8fSRyan Holt 271847a6f8fSRyan Holt // Check low >= baseLow 272847a6f8fSRyan Holt auto geLow = builder.createOrFold<arith::CmpIOp>( 273847a6f8fSRyan Holt loc, arith::CmpIPredicate::sge, low, baseLow); 274847a6f8fSRyan Holt 275847a6f8fSRyan Holt // Check high <= baseHigh 276847a6f8fSRyan Holt auto leHigh = builder.createOrFold<arith::CmpIOp>( 277847a6f8fSRyan Holt loc, arith::CmpIPredicate::sle, high, baseHigh); 278847a6f8fSRyan Holt 279847a6f8fSRyan Holt auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); 280847a6f8fSRyan Holt 281847a6f8fSRyan Holt builder.create<cf::AssertOp>( 282847a6f8fSRyan Holt loc, assertCond, 283d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 284d94aeb50SRyan Holt op, "subview is out-of-bounds of the base memref")); 285847a6f8fSRyan Holt } 286847a6f8fSRyan Holt }; 287847a6f8fSRyan Holt 288108b08f2SMatthias Springer struct ExpandShapeOpInterface 289108b08f2SMatthias Springer : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface, 290108b08f2SMatthias Springer ExpandShapeOp> { 291108b08f2SMatthias Springer void generateRuntimeVerification(Operation *op, OpBuilder &builder, 292108b08f2SMatthias Springer Location loc) const { 293108b08f2SMatthias Springer auto expandShapeOp = cast<ExpandShapeOp>(op); 294108b08f2SMatthias Springer 295108b08f2SMatthias Springer // Verify that the expanded dim sizes are a product of the collapsed dim 296108b08f2SMatthias Springer // size. 29770423eedSAdrian Kuegel for (const auto &it : 29870423eedSAdrian Kuegel llvm::enumerate(expandShapeOp.getReassociationIndices())) { 299108b08f2SMatthias Springer Value srcDimSz = 300108b08f2SMatthias Springer builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index()); 301108b08f2SMatthias Springer int64_t groupSz = 1; 302108b08f2SMatthias Springer bool foundDynamicDim = false; 303108b08f2SMatthias Springer for (int64_t resultDim : it.value()) { 304108b08f2SMatthias Springer if (expandShapeOp.getResultType().isDynamicDim(resultDim)) { 305108b08f2SMatthias Springer // Keep this assert here in case the op is extended in the future. 306108b08f2SMatthias Springer assert(!foundDynamicDim && 307108b08f2SMatthias Springer "more than one dynamic dim found in reassoc group"); 30802f4cfa3SKazu Hirata (void)foundDynamicDim; 309108b08f2SMatthias Springer foundDynamicDim = true; 310108b08f2SMatthias Springer continue; 311108b08f2SMatthias Springer } 312108b08f2SMatthias Springer groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); 313108b08f2SMatthias Springer } 314108b08f2SMatthias Springer Value staticResultDimSz = 315108b08f2SMatthias Springer builder.create<arith::ConstantIndexOp>(loc, groupSz); 316108b08f2SMatthias Springer // staticResultDimSz must divide srcDimSz evenly. 317108b08f2SMatthias Springer Value mod = 318108b08f2SMatthias Springer builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz); 319108b08f2SMatthias Springer Value isModZero = builder.create<arith::CmpIOp>( 320108b08f2SMatthias Springer loc, arith::CmpIPredicate::eq, mod, 321108b08f2SMatthias Springer builder.create<arith::ConstantIndexOp>(loc, 0)); 322108b08f2SMatthias Springer builder.create<cf::AssertOp>( 323108b08f2SMatthias Springer loc, isModZero, 324d94aeb50SRyan Holt RuntimeVerifiableOpInterface::generateErrorMessage( 325d94aeb50SRyan Holt op, "static result dims in reassoc group do not " 3265eee80ceSMatthias Springer "divide src dim evenly")); 327108b08f2SMatthias Springer } 328108b08f2SMatthias Springer } 329108b08f2SMatthias Springer }; 330108b08f2SMatthias Springer } // namespace 331108b08f2SMatthias Springer } // namespace memref 332108b08f2SMatthias Springer } // namespace mlir 333108b08f2SMatthias Springer 334108b08f2SMatthias Springer void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( 335108b08f2SMatthias Springer DialectRegistry ®istry) { 336108b08f2SMatthias Springer registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { 3375eee80ceSMatthias Springer CastOp::attachInterface<CastOpInterface>(*ctx); 338108b08f2SMatthias Springer ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); 339847a6f8fSRyan Holt LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx); 340847a6f8fSRyan Holt ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx); 341847a6f8fSRyan Holt StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx); 342847a6f8fSRyan Holt SubViewOp::attachInterface<SubViewOpInterface>(*ctx); 343108b08f2SMatthias Springer 344108b08f2SMatthias Springer // Load additional dialects of which ops may get created. 345847a6f8fSRyan Holt ctx->loadDialect<affine::AffineDialect, arith::ArithDialect, 346847a6f8fSRyan Holt cf::ControlFlowDialect>(); 347108b08f2SMatthias Springer }); 348108b08f2SMatthias Springer } 349