//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements decompose memrefs pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}

static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
  return std::make_unique<GpuDecomposeMemrefsPass>();
}