//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <algorithm>
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, a `broadcast`, or a mixture of the two, into an explicit
/// transpose and broadcast. Having them folded into the linalg.generic is a
/// good optimization, but sometimes we may want to unwrap, i.e. `unfold`,
/// them as explicit transpose and broadcast ops. This rewrite pattern does
/// that for each input operand, which is useful for instance when trying to
/// recognize named ops.
///
/// The transpose, broadcast, or mixture of both is expressed in the affine
/// map of the operand; technically it is a `projected permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///    { indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"] }
///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///    outs(%z : tensor<5x9x7x8x10xf32>) {
///    ^bb0(%in: f32, %in_1: f32, %out: f32):
///      %div = arith.divf %in, %in_1 : f32
///      linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR the map of operand `%x` is a projected permutation. It can
/// be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///                ins(%x : tensor<7x8x9xf32>)
///                outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///                ins(%x_trans : tensor<9x7x8xf32>)
///                outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///        ins(%x_trans_bc, %y :
///              tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///        outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that the linalg.generic has been 'specialized' to linalg.div.
///
/// When unfolding, it is more efficient to transpose first and then
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced (pre-broadcast) dimensions.
/// Also, the broadcast dimensions of a linalg.generic come from the other
/// operands (those not broadcast along that particular dimension). We work
/// this out by computing the convex-polyhedron shape of the linalg.generic
/// iteration space from the shapes of all the operands, both inputs and
/// outputs.
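///
/// As a sketch of how the two pieces are recovered for the example above
/// (derived from #projection, i.e. (d0, d1, d2, d3, d4) -> (d2, d3, d1), by
/// `computeTransposeBroadcast` below, and matching the IR shown):
///
///   * the result dims {2, 3, 1}, re-based onto three dims as (d1, d2, d0),
///     yield (as the inverse of that re-based list) the transpose
///     permutation [2, 0, 1];
///   * the loop dims that never appear in the results, {0, 4}, become the
///     broadcast dimensions [0, 4].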
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine which dimensions are transposed and which
/// dimensions are broadcast.
/// Returns the pair `(transpose-permutation, broadcast-dimensions)`; either
/// entry is empty when not needed.
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projected permutation");

  // As the map is a projection, it operates on a smaller set of dimensions as
  // far as the transpose is concerned (the rest are broadcast).
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If the dims are not monotonically increasing then a transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  std::sort(sortedResMap.begin(), sortedResMap.end());
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcast.
  SmallVector<int64_t> broadcast;
  int64_t numDims = map.getNumInputs();
  for (int64_t i = 0, j = 0; i < numDims; ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has the
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
    // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
    // referring to d3, d4. Therefore, re-base the transpose dimensions so
    // that they start from d0.
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];

    // Calculate the permutation for the transpose.
    for (int64_t i = 0; i < minorSize; ++i)
      permutation[remappedResult[i]] = i;
  }
  return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then it cannot
  // be decomposed into a mere transpose and broadcast. The requirement that
  // all maps be `projected permutation` may be over-restrictive, but since we
  // also need to determine the shape of the iteration space, reject the op if
  // any map violates that assumption.
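  // For instance (illustrative examples, not taken from the code above):
  // (d0, d1, d2) -> (d2, d0) is a projected permutation, whereas
  // (d0, d1, d2) -> (d0 + d1, d2) and (d0, d1, d2) -> (d0, d0) are not, and
  // would make this pattern bail out here.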
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(false))
      return failure();
  }

  // Decomposing the linalg.generic involves creating `tensor.empty` ops.
  // Those could have dynamic shapes, but then we would have to work out
  // which operand can supply each runtime value (via tensor.dim). Leaving
  // that as a future TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if its affine map encodes a
  // transpose, a broadcast, or a mix of the two.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) = dim(input, permutation[i]).
      int64_t numResults = map.getNumResults();
      SmallVector<int64_t> transposedShape(numResults);
      for (int64_t j = 0; j < numResults; ++j)
        transposedShape[j] = inputRTType.getShape()[permutation[j]];

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
    if (!broadcastedDims.empty()) {
      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);

      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  // If nothing was unfolded, report failure so the rewrite driver does not
  // treat an unmodified op as progress.
  if (!isChanged)
    return failure();

  SmallVector<Value> operands = op->getOperands();
  ValueRange operandsRef(operands);

  auto newOp = rewriter.create<linalg::GenericOp>(
      /*location=*/op.getLoc(),
      /*resultTensorTypes=*/op->getResultTypes(),
      /*inputs=*/newInitValues,
      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
      /*indexingMaps=*/newMap,
      /*iteratorTypes=*/op.getIteratorTypesArray());

  newOp.getRegion().takeBody(op->getRegion(0));
  rewriter.replaceOp(op, newOp->getResults());
  return success();
}

} // namespace

void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}
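
// Example usage (a minimal sketch, not part of this file): a client pass
// would typically hand the populated set to the greedy rewrite driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h:
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeProjectedPermutationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();
//
// The surrounding pass boilerplate (getContext, getOperation,
// signalPassFailure) is assumed to come from the enclosing mlir::Pass
// subclass and is not defined here.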