//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, `broadcast`, or a mixture of the two, into an explicit
/// transpose and broadcast. Having them folded into the linalg.generic is a
/// good optimization, but sometimes we may want to unwrap, i.e. `unfold`,
/// them as explicit transpose and broadcast ops. This rewrite pattern does
/// that for each input operand. It is useful, for instance, when trying to
/// recognize named ops.
///
/// The transpose, broadcast, or mixture of both is expressed in the affine
/// map of the operand; technically, such a map is a `projected permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///    { indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"] }
///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///    outs(%z : tensor<5x9x7x8x10xf32>) {
///      ^bb0(%in: f32, %in_1: f32, %out: f32):
///        %div = arith.divf %in, %in_1 : f32
///        linalg.yield %div : f32
///    } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation. It
/// can be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///              ins(%x : tensor<7x8x9xf32>)
///              outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///                 ins(%x_trans : tensor<9x7x8xf32>)
///                 outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///        ins(%x_trans_bc, %y :
///            tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///        outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that the linalg.generic has been 'specialized' to linalg.div.
///
/// When unfolding, it is more efficient to transpose first and then
/// broadcast. However, if the transpose is done first, the permutation has to
/// be expressed in terms of the reduced dimensions, as the broadcast has not
/// happened yet. Also, the sizes of the broadcast dimensions of an operand
/// come from the other operands (those not broadcast along that particular
/// dimension). We work this out by computing the shape of the linalg.generic
/// iteration space from the shapes of all the operands, both inputs and
/// outputs.
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine which dimensions are transposed and which
/// dimensions are broadcast.
/// Returns:
///   `transpose-permutation`, `broadcast-dimensions` (each empty if not
///   needed).
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");

  // As the map is a projection, it operates on a smaller set of dimensions as
  // far as the transpose is concerned (the rest are broadcast).
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If the dims are not monotonically increasing then a transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  std::sort(sortedResMap.begin(), sortedResMap.end());
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcast.
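  // E.g., for the doc-comment example above (5 loop dims, map results
  // (d2, d3, d1), hence sortedResMap = [1, 2, 3]), d0 and d4 never appear in
  // the results, so the broadcast dimensions are [0, 4].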
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)`, the access has to be
    // expressed in terms of the operand's own dimensions, i.e. as
    // `affine_map<(d0, d1, d2) -> (d1, d2, d0)>`, instead of referring to
    // d3, d4. Therefore, re-base the transpose dimensions so that they start
    // from d0.
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];

    // Calculate the permutation for the transpose.
    for (int64_t i = 0; i < minorSize; ++i)
      permutation[remappedResult[i]] = i;
  }
  return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then it cannot
  // be decomposed into a mere transpose and broadcast. Requiring that all
  // maps be `projected permutation` may be over-restrictive, but since we
  // also need to determine the shape of the iteration space, reject the op
  // if any map violates that assumption.
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(false))
      return failure();
  }

  // Decomposing a linalg.generic involves creating `tensor.empty` ops, which
  // could have dynamic shapes, but then we would have to work out which
  // operand supplies each runtime dimension (via tensor.dim). Leaving that as
  // a future TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if it is transposed,
  // broadcast, or a mix of the two, as indicated by the operand's affine map.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) = dim(input, permutation[i]).
      SmallVector<int64_t> transposedShape(map.getNumResults());
      for (int64_t dim = 0; dim < map.getNumResults(); ++dim)
        transposedShape[dim] = inputRTType.getShape()[permutation[dim]];

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
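    // If so, broadcast the (possibly transposed) operand up to the full
    // iteration-space shape. In the doc-comment example, `broadcastedDims` is
    // [0, 4], so tensor<9x7x8xf32> is broadcast to tensor<5x9x7x8x10xf32>.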
    if (!broadcastedDims.empty()) {
      assert(!broadcastedDims.empty() && "should have non-empty broadcast dims");
      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);

      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (isChanged) {
    SmallVector<Value> operands = op->getOperands();
    ValueRange operandsRef(operands);

    auto newOp = rewriter.create<linalg::GenericOp>(
        /*location=*/op.getLoc(),
        /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/newInitValues,
        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
        /*indexingMaps=*/newMap,
        /*iteratorTypes=*/op.getIteratorTypesArray());

    newOp.getRegion().takeBody(op->getRegion(0));
    rewriter.replaceOp(op, newOp->getResults());
  }
  // Report failure when no operand needed unfolding, so the driver does not
  // treat an unchanged op as rewritten.
  return isChanged ? success() : failure();
}

} // namespace

void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}
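
// Usage sketch (illustrative only, not part of this file's API): callers
// typically add these patterns to a set and run them with the greedy rewrite
// driver, e.g. inside some pass, where `funcOp` is an assumed enclosing
// function-like op:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeProjectedPermutationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));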