1 //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass resolves `memref.dim` operations of result values in terms of 10 // shapes of their operands using the `InferShapedTypeOpInterface`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 15 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/Arith/IR/Arith.h" 18 #include "mlir/Dialect/Arith/Utils/Utils.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 21 #include "mlir/Dialect/SCF/IR/SCF.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Interfaces/InferTypeOpInterface.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 26 namespace mlir { 27 namespace memref { 28 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS 29 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS 30 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 31 } // namespace memref 32 } // namespace mlir 33 34 using namespace mlir; 35 36 namespace { 37 /// Fold dim of an operation that implements the InferShapedTypeOpInterface 38 template <typename OpTy> 39 struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> { 40 using OpRewritePattern<OpTy>::OpRewritePattern; 41 42 LogicalResult matchAndRewrite(OpTy dimOp, 43 PatternRewriter &rewriter) const override { 44 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource()); 45 if (!dimValue) 46 return failure(); 47 auto shapedTypeOp = 48 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner()); 49 if (!shapedTypeOp) 50 return failure(); 51 52 std::optional<int64_t> dimIndex = dimOp.getConstantIndex(); 53 if (!dimIndex) 54 return failure(); 55 56 SmallVector<Value> reifiedResultShapes; 57 if (failed(shapedTypeOp.reifyReturnTypeShapes( 58 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes))) 59 return failure(); 60 61 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults()) 62 return failure(); 63 64 Value resultShape = reifiedResultShapes[dimValue.getResultNumber()]; 65 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType()); 66 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType())) 67 return failure(); 68 69 Location loc = dimOp->getLoc(); 70 rewriter.replaceOpWithNewOp<tensor::ExtractOp>( 71 dimOp, resultShape, 72 rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult()); 73 return success(); 74 } 75 }; 76 77 /// Fold dim of an operation that implements the InferShapedTypeOpInterface 78 template <typename OpTy> 79 struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> { 80 using OpRewritePattern<OpTy>::OpRewritePattern; 81 82 void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); } 83 84 LogicalResult matchAndRewrite(OpTy dimOp, 85 PatternRewriter &rewriter) const override { 86 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource()); 87 if (!dimValue) 88 return failure(); 89 std::optional<int64_t> dimIndex = dimOp.getConstantIndex(); 90 if (!dimIndex) 91 return failure(); 92 93 ReifiedRankedShapedTypeDims reifiedResultShapes; 94 if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), 95 reifiedResultShapes))) 96 return failure(); 97 unsigned resultNumber = dimValue.getResultNumber(); 98 // Do not apply pattern if the IR is invalid (dim out of bounds). 99 if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size()) 100 return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds"); 101 Value replacement = getValueOrCreateConstantIndexOp( 102 rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); 103 rewriter.replaceOp(dimOp, replacement); 104 return success(); 105 } 106 }; 107 108 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: 109 /// 110 /// ``` 111 /// %0 = ... : tensor<?x?xf32> 112 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) { 113 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 114 /// ... 115 /// } 116 /// ``` 117 /// 118 /// is folded to: 119 /// 120 /// ``` 121 /// %0 = ... : tensor<?x?xf32> 122 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) { 123 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 124 /// ... 125 /// } 126 /// ``` 127 struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> { 128 using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 129 130 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 131 PatternRewriter &rewriter) const final { 132 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource()); 133 if (!blockArg) 134 return failure(); 135 // TODO: Enable this for loopLikeInterface. Restricting for scf.for 136 // because the init args shape might change in the loop body. 137 // For e.g.: 138 // ``` 139 // %0 = tensor.empty(%c1) : tensor<?xf32> 140 // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) -> 141 // tensor<?xf32> { 142 // %1 = tensor.dim %arg0, %c0 : tensor<?xf32> 143 // %2 = arith.addi %c1, %1 : index 144 // %3 = tensor.empty(%2) : tensor<?xf32> 145 // scf.yield %3 : tensor<?xf32> 146 // } 147 // 148 // ``` 149 auto forAllOp = 150 dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp()); 151 if (!forAllOp) 152 return failure(); 153 Value initArg = forAllOp.getTiedLoopInit(blockArg)->get(); 154 rewriter.modifyOpInPlace( 155 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); }); 156 return success(); 157 } 158 }; 159 } // namespace 160 161 //===----------------------------------------------------------------------===// 162 // Pass registration 163 //===----------------------------------------------------------------------===// 164 165 namespace { 166 struct ResolveRankedShapeTypeResultDimsPass final 167 : public memref::impl::ResolveRankedShapeTypeResultDimsBase< 168 ResolveRankedShapeTypeResultDimsPass> { 169 void runOnOperation() override; 170 }; 171 172 struct ResolveShapedTypeResultDimsPass final 173 : public memref::impl::ResolveShapedTypeResultDimsBase< 174 ResolveShapedTypeResultDimsPass> { 175 void runOnOperation() override; 176 }; 177 178 } // namespace 179 180 void memref::populateResolveRankedShapedTypeResultDimsPatterns( 181 RewritePatternSet &patterns) { 182 patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>, 183 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>, 184 IterArgsToInitArgs>(patterns.getContext()); 185 } 186 187 void memref::populateResolveShapedTypeResultDimsPatterns( 188 RewritePatternSet &patterns) { 189 // TODO: Move tensor::DimOp pattern to the Tensor dialect. 190 patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>, 191 DimOfShapedTypeOpInterface<tensor::DimOp>>( 192 patterns.getContext()); 193 } 194 195 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { 196 RewritePatternSet patterns(&getContext()); 197 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); 198 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 199 return signalPassFailure(); 200 } 201 202 void ResolveShapedTypeResultDimsPass::runOnOperation() { 203 RewritePatternSet patterns(&getContext()); 204 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); 205 memref::populateResolveShapedTypeResultDimsPatterns(patterns); 206 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 207 return signalPassFailure(); 208 } 209 210 std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() { 211 return std::make_unique<ResolveShapedTypeResultDimsPass>(); 212 } 213 214 std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() { 215 return std::make_unique<ResolveRankedShapeTypeResultDimsPass>(); 216 } 217