1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Arith/Utils/Utils.h" 14 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 15 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 18 #include "mlir/Dialect/Utils/IndexingUtils.h" 19 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" 20 21 using namespace mlir; 22 23 namespace mlir { 24 namespace memref { 25 namespace { 26 struct CastOpInterface 27 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, 28 CastOp> { 29 void generateRuntimeVerification(Operation *op, OpBuilder &builder, 30 Location loc) const { 31 auto castOp = cast<CastOp>(op); 32 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType()); 33 34 // Nothing to check if the result is an unranked memref. 35 auto resultType = dyn_cast<MemRefType>(castOp.getType()); 36 if (!resultType) 37 return; 38 39 if (isa<UnrankedMemRefType>(srcType)) { 40 // Check rank. 41 Value srcRank = builder.create<RankOp>(loc, castOp.getSource()); 42 Value resultRank = 43 builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); 44 Value isSameRank = builder.create<arith::CmpIOp>( 45 loc, arith::CmpIPredicate::eq, srcRank, resultRank); 46 builder.create<cf::AssertOp>( 47 loc, isSameRank, 48 RuntimeVerifiableOpInterface::generateErrorMessage(op, 49 "rank mismatch")); 50 } 51 52 // Get source offset and strides. We do not have an op to get offsets and 53 // strides from unranked memrefs, so cast the source to a type with fully 54 // dynamic layout, from which we can then extract the offset and strides. 55 // (Rank was already verified.) 56 int64_t dynamicOffset = ShapedType::kDynamic; 57 SmallVector<int64_t> dynamicShape(resultType.getRank(), 58 ShapedType::kDynamic); 59 auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), 60 dynamicOffset, dynamicShape); 61 auto dynStridesType = 62 MemRefType::get(dynamicShape, resultType.getElementType(), 63 stridedLayout, resultType.getMemorySpace()); 64 Value helperCast = 65 builder.create<CastOp>(loc, dynStridesType, castOp.getSource()); 66 auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast); 67 68 // Check dimension sizes. 69 for (const auto &it : llvm::enumerate(resultType.getShape())) { 70 // Static dim size -> static/dynamic dim size does not need verification. 71 if (auto rankedSrcType = dyn_cast<MemRefType>(srcType)) 72 if (!rankedSrcType.isDynamicDim(it.index())) 73 continue; 74 75 // Static/dynamic dim size -> dynamic dim size does not need verification. 76 if (resultType.isDynamicDim(it.index())) 77 continue; 78 79 Value srcDimSz = 80 builder.create<DimOp>(loc, castOp.getSource(), it.index()); 81 Value resultDimSz = 82 builder.create<arith::ConstantIndexOp>(loc, it.value()); 83 Value isSameSz = builder.create<arith::CmpIOp>( 84 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); 85 builder.create<cf::AssertOp>( 86 loc, isSameSz, 87 RuntimeVerifiableOpInterface::generateErrorMessage( 88 op, "size mismatch of dim " + std::to_string(it.index()))); 89 } 90 91 // Get result offset and strides. 92 int64_t resultOffset; 93 SmallVector<int64_t> resultStrides; 94 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) 95 return; 96 97 // Check offset. 98 if (resultOffset != ShapedType::kDynamic) { 99 // Static/dynamic offset -> dynamic offset does not need verification. 100 Value srcOffset = metadataOp.getResult(1); 101 Value resultOffsetVal = 102 builder.create<arith::ConstantIndexOp>(loc, resultOffset); 103 Value isSameOffset = builder.create<arith::CmpIOp>( 104 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); 105 builder.create<cf::AssertOp>( 106 loc, isSameOffset, 107 RuntimeVerifiableOpInterface::generateErrorMessage( 108 op, "offset mismatch")); 109 } 110 111 // Check strides. 112 for (const auto &it : llvm::enumerate(resultStrides)) { 113 // Static/dynamic stride -> dynamic stride does not need verification. 114 if (it.value() == ShapedType::kDynamic) 115 continue; 116 117 Value srcStride = 118 metadataOp.getResult(2 + resultType.getRank() + it.index()); 119 Value resultStrideVal = 120 builder.create<arith::ConstantIndexOp>(loc, it.value()); 121 Value isSameStride = builder.create<arith::CmpIOp>( 122 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); 123 builder.create<cf::AssertOp>( 124 loc, isSameStride, 125 RuntimeVerifiableOpInterface::generateErrorMessage( 126 op, "stride mismatch of dim " + std::to_string(it.index()))); 127 } 128 } 129 }; 130 131 /// Verifies that the indices on load/store ops are in-bounds of the memref's 132 /// index space: 0 <= index#i < dim#i 133 template <typename LoadStoreOp> 134 struct LoadStoreOpInterface 135 : public RuntimeVerifiableOpInterface::ExternalModel< 136 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> { 137 void generateRuntimeVerification(Operation *op, OpBuilder &builder, 138 Location loc) const { 139 auto loadStoreOp = cast<LoadStoreOp>(op); 140 141 auto memref = loadStoreOp.getMemref(); 142 auto rank = memref.getType().getRank(); 143 if (rank == 0) { 144 return; 145 } 146 auto indices = loadStoreOp.getIndices(); 147 148 auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); 149 Value assertCond; 150 for (auto i : llvm::seq<int64_t>(0, rank)) { 151 auto index = indices[i]; 152 153 auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i); 154 155 auto geLow = builder.createOrFold<arith::CmpIOp>( 156 loc, arith::CmpIPredicate::sge, index, zero); 157 auto ltHigh = builder.createOrFold<arith::CmpIOp>( 158 loc, arith::CmpIPredicate::slt, index, dimOp); 159 auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh); 160 161 assertCond = 162 i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp) 163 : andOp; 164 } 165 builder.create<cf::AssertOp>( 166 loc, assertCond, 167 RuntimeVerifiableOpInterface::generateErrorMessage( 168 op, "out-of-bounds access")); 169 } 170 }; 171 172 /// Compute the linear index for the provided strided layout and indices. 173 Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset, 174 ArrayRef<OpFoldResult> strides, 175 ArrayRef<OpFoldResult> indices) { 176 auto [expr, values] = computeLinearIndex(offset, strides, indices); 177 auto index = 178 affine::makeComposedFoldedAffineApply(builder, loc, expr, values); 179 return getValueOrCreateConstantIndexOp(builder, loc, index); 180 } 181 182 /// Returns two Values representing the bounds of the provided strided layout 183 /// metadata. The bounds are returned as a half open interval -- [low, high). 184 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, 185 OpFoldResult offset, 186 ArrayRef<OpFoldResult> strides, 187 ArrayRef<OpFoldResult> sizes) { 188 auto zeros = SmallVector<int64_t>(sizes.size(), 0); 189 auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros); 190 auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices); 191 auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes); 192 return {lowerBound, upperBound}; 193 } 194 195 /// Returns two Values representing the bounds of the memref. The bounds are 196 /// returned as a half open interval -- [low, high). 197 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, 198 TypedValue<BaseMemRefType> memref) { 199 auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref); 200 auto offset = runtimeMetadata.getConstifiedMixedOffset(); 201 auto strides = runtimeMetadata.getConstifiedMixedStrides(); 202 auto sizes = runtimeMetadata.getConstifiedMixedSizes(); 203 return computeLinearBounds(builder, loc, offset, strides, sizes); 204 } 205 206 /// Verifies that the linear bounds of a reinterpret_cast op are within the 207 /// linear bounds of the base memref: low >= baseLow && high <= baseHigh 208 struct ReinterpretCastOpInterface 209 : public RuntimeVerifiableOpInterface::ExternalModel< 210 ReinterpretCastOpInterface, ReinterpretCastOp> { 211 void generateRuntimeVerification(Operation *op, OpBuilder &builder, 212 Location loc) const { 213 auto reinterpretCast = cast<ReinterpretCastOp>(op); 214 auto baseMemref = reinterpretCast.getSource(); 215 auto resultMemref = 216 cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult()); 217 218 builder.setInsertionPointAfter(op); 219 220 // Compute the linear bounds of the base memref 221 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); 222 223 // Compute the linear bounds of the resulting memref 224 auto [low, high] = computeLinearBounds(builder, loc, resultMemref); 225 226 // Check low >= baseLow 227 auto geLow = builder.createOrFold<arith::CmpIOp>( 228 loc, arith::CmpIPredicate::sge, low, baseLow); 229 230 // Check high <= baseHigh 231 auto leHigh = builder.createOrFold<arith::CmpIOp>( 232 loc, arith::CmpIPredicate::sle, high, baseHigh); 233 234 auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); 235 236 builder.create<cf::AssertOp>( 237 loc, assertCond, 238 RuntimeVerifiableOpInterface::generateErrorMessage( 239 op, 240 "result of reinterpret_cast is out-of-bounds of the base memref")); 241 } 242 }; 243 244 /// Verifies that the linear bounds of a subview op are within the linear bounds 245 /// of the base memref: low >= baseLow && high <= baseHigh 246 /// TODO: This is not yet a full runtime verification of subview. For example, 247 /// consider: 248 /// %m = memref.alloc(%c10, %c10) : memref<10x10xf32> 249 /// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1] 250 /// : memref<?x?xf32> to memref<?x?xf32> 251 /// The subview is in-bounds of the entire base memref but the first dimension 252 /// is out-of-bounds. Future work would verify the bounds on a per-dimension 253 /// basis. 254 struct SubViewOpInterface 255 : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface, 256 SubViewOp> { 257 void generateRuntimeVerification(Operation *op, OpBuilder &builder, 258 Location loc) const { 259 auto subView = cast<SubViewOp>(op); 260 auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource()); 261 auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult()); 262 263 builder.setInsertionPointAfter(op); 264 265 // Compute the linear bounds of the base memref 266 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); 267 268 // Compute the linear bounds of the resulting memref 269 auto [low, high] = computeLinearBounds(builder, loc, resultMemref); 270 271 // Check low >= baseLow 272 auto geLow = builder.createOrFold<arith::CmpIOp>( 273 loc, arith::CmpIPredicate::sge, low, baseLow); 274 275 // Check high <= baseHigh 276 auto leHigh = builder.createOrFold<arith::CmpIOp>( 277 loc, arith::CmpIPredicate::sle, high, baseHigh); 278 279 auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); 280 281 builder.create<cf::AssertOp>( 282 loc, assertCond, 283 RuntimeVerifiableOpInterface::generateErrorMessage( 284 op, "subview is out-of-bounds of the base memref")); 285 } 286 }; 287 288 struct ExpandShapeOpInterface 289 : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface, 290 ExpandShapeOp> { 291 void generateRuntimeVerification(Operation *op, OpBuilder &builder, 292 Location loc) const { 293 auto expandShapeOp = cast<ExpandShapeOp>(op); 294 295 // Verify that the expanded dim sizes are a product of the collapsed dim 296 // size. 297 for (const auto &it : 298 llvm::enumerate(expandShapeOp.getReassociationIndices())) { 299 Value srcDimSz = 300 builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index()); 301 int64_t groupSz = 1; 302 bool foundDynamicDim = false; 303 for (int64_t resultDim : it.value()) { 304 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) { 305 // Keep this assert here in case the op is extended in the future. 306 assert(!foundDynamicDim && 307 "more than one dynamic dim found in reassoc group"); 308 (void)foundDynamicDim; 309 foundDynamicDim = true; 310 continue; 311 } 312 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); 313 } 314 Value staticResultDimSz = 315 builder.create<arith::ConstantIndexOp>(loc, groupSz); 316 // staticResultDimSz must divide srcDimSz evenly. 317 Value mod = 318 builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz); 319 Value isModZero = builder.create<arith::CmpIOp>( 320 loc, arith::CmpIPredicate::eq, mod, 321 builder.create<arith::ConstantIndexOp>(loc, 0)); 322 builder.create<cf::AssertOp>( 323 loc, isModZero, 324 RuntimeVerifiableOpInterface::generateErrorMessage( 325 op, "static result dims in reassoc group do not " 326 "divide src dim evenly")); 327 } 328 } 329 }; 330 } // namespace 331 } // namespace memref 332 } // namespace mlir 333 334 void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( 335 DialectRegistry ®istry) { 336 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { 337 CastOp::attachInterface<CastOpInterface>(*ctx); 338 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); 339 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx); 340 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx); 341 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx); 342 SubViewOp::attachInterface<SubViewOpInterface>(*ctx); 343 344 // Load additional dialects of which ops may get created. 345 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect, 346 cf::ControlFlowDialect>(); 347 }); 348 } 349