//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// peelLoop transformation.
//===----------------------------------------------------------------------===//

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) -> SmallVector<Value> {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return llvm::to_vector(partialIteration->getResults());
        assert(!partialIteration && "expected that loop was not peeled");
        return llvm::to_vector(forOp->getResults());
      })
      .Default([&](Operation *op) -> SmallVector<Value> {
        return llvm::to_vector(op->getResults());
      });
}

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
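/// E.g., with dim = 1, the map (d0, d1) -> (d0, d1) has exactly one result
/// that is a function of d1, while (d0, d1) -> (d0 + d1, d1) has two and is
/// rejected.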
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}
#endif // NDEBUG

/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()` by:
///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
///   2. Appending a `newDim` to the domain of every indexing map.
///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
///      by potentially adding a `newDim` result to `map`.
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
/// Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into higher dimensional access is not possible for general
/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// And would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to map whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
    // logically chunks dimension `dim` into `K * dim + newDim`, where the
    // packing factor `K` is specified separately.
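    // E.g., packing dim = 1 of an operand with map (d0, d1) -> (d0, d1)
    // yields (d0, d1, d2) -> (d0, d1, d2): the original d1 result now indexes
    // chunks of the packed dimension and the appended d2 indexes within a
    // chunk.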
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace

FailureOr<LowerPackResult>
linalg::lowerPack(RewriterBase &rewriter, tensor::PackOp packOp,
                  bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
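  // The amount of high padding for each packed dimension is
  //   high = outerSize * innerSize - origSize,
  // computed below with the affine map (d0, d1)[s0] -> (d0 * s0 - d1) where
  // d0 = outerSize, d1 = origSize and s0 = innerSize.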
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
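  // The transposition permutation is the inverse of the packed-to-strip-mined
  // permutation computed in step 2, e.g. [1, 2, 0] inverts to [2, 0, 1].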
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}

FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
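  // The tensor.empty below must carry the strip-mined shape, so the source
  // sizes (in packed order) are permuted into strip-mined order first.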
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}

SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}

/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););

  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
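    // (A packed size of 0 is the convention for "leave this loop unpacked".)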
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n";
                 llvm::interleaveComma(innerPos, DBGS() << "innerPos: ");
                 DBGSNL();
                 llvm::interleaveComma(innerPackSizes,
                                       DBGS() << "innerPackSizes: ");
                 DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
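    // The unpack reuses the pack's innerDimsPos and tile sizes and writes
    // back into the original (unpacked) init tensor, so the replacement
    // results keep their pre-packing types.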
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method in MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
///   - the corresponding indexing map is transposed by `permutation`
///   - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
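  // A transposed clone of packOp is created up front; the original pack is
  // only erased in step 4, once all of its uses have been rewired.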
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // The leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);
    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 6 possible forms. The outer
/// dimensions of the operands are not permuted at this time, this is left for
/// future work.
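/// For example (illustrative sizes only), mnkPackedSizes = {8, 16, 32} with
/// mnkOrder = {0, 1, 2} packs the inferred m, n and k loops by 8, 16 and 32
/// respectively.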
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k, bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where it would have
  // to plug a heuristic.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalizes the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  //   - LHS{leading_lhs, kk, mm}
  //   - RHS{leading_rhs, kk, nn}
  //   - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
  // post interchange.
  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
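  // The pad op's region yields the (possibly index-dependent) padding value;
  // cloning it into the generate op preserves that behavior.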
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(),
            cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
  return success();
}

/// If padding value is set, returns a tensor.pad Op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack Op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims.
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}

// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}

// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the inner most dims after packing.
  perm.append(innerDims);

  return perm;
}

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  // TODO: support the case that outer dimensions are not all 1s. A
  // tensor.expand_shape will be generated in this case.
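  // Schematically, the rewrite below produces (an illustrative sketch, types
  // elided):
  //   %padded = tensor.pad %src ...        (only if a padding value is set)
  //   %empty = tensor.empty ...
  //   %transposed = linalg.transpose ins(%padded_or_src) outs(%empty) ...
  //   %res = tensor.insert_slice %transposed into %dest ...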
  if (llvm::any_of(packOp.getAllOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();
  int64_t numTiles = destRank - srcRank;

  if (!llvm::all_of(packOp.getInnerDimsPos(),
                    [&srcRank, &numTiles](int64_t dimPos) {
                      return dimPos >= (srcRank - numTiles - 1);
                    }))
    return rewriter.notifyMatchFailure(
        packOp, "Attempting to tile non-trailing source dims!");

  // 1. Extract the inner tile sizes.
  // Where possible, values are replaced with constant attributes (to match the
  // behaviour of `getPackOpSourceOrPaddedSource`).
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      // Rather than taking the tile size as is, extract the actual constant
      // value Attribute where possible, e.g.:
      //   [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
      auto [_, tileSize] =
          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
      tileSizes.push_back(tileSize);
    }
  }

  // 2. Transpose the input to match the inner tile order:
  //    %init = tensor.empty()
  //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
  //                                        outs(%init)
  // Two assumptions are made:
  //   1. All outer dims are 1 - the corresponding transposition doesn't
  //      matter.
  //   2. Inner dims positions correspond to the trailing `numTiles` dims.
  SmallVector<int64_t> tilesPermNormalized =
      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);
  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
             DBGSNL(););

  // 2.1 Create tensor.empty (init value for TransposeOp).
  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
                                                 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);

  applyPermutationToVector(transShapeForEmptyOp, srcPermForTranspose);
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2 Create linalg.transpose.
  auto transposedOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, srcPermForTranspose);

  // 3. Insert the inner tile to the destination:
  //    %inserted_tile = tensor.insert_slice(%transposed_tile)
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  // Outer dims are all 1s!
  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
                                       oneIdxAttr);
  SmallVector<int64_t> writeShape;

  for (auto tileSize : packOp.getMixedTiles()) {
    auto [tileSizeStatic, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
    writeShape.push_back(tileSizeStatic);
  }

  // 4. Replace tensor.packOp with the tensor.insert_slice created above.
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}

LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
  // dims:
  //   [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
  SmallVector<int64_t> readShapeForExtractSlice;
  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
  // outer-tiled-dims being all 1), this will be
  //   [ outer-untiled-dims, tile-sizes ]
  SmallVector<OpFoldResult> extractSliceSizes;
  // The offset and strides attributes for ExtractSliceOp.
  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);

  // Shape for EmptyOp that's used as the init value for TransposeOp below.
  // This should be:
  //   [ outer-untiled-dims, tile-sizes ]
  // However, skip unit dims - TransposeOp (below) applies a rank-reduced
  // permutation.
  SmallVector<OpFoldResult> shapeForEmptyOp;

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
    //
    // As all outer tiled dims are 1, the corresponding slice size to read
    // will also be 1. As this will be a rank-reducing "extract slice" (i.e.
    // the unit dims will be "collapsed"), there's no need to update:
    //   * the output shape for ExtractSliceOp, nor
    //   * the shape for EmptyOp.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Compute sizes attribute for ExtractSliceOp + EmptyOp -
    // outer-untiled-dims.
    if (ShapedType::isDynamic(srcShape[i])) {
      OpFoldResult dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
    // into account rank-reducing).
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  // Append the tile sizes to the "sizes attribute" for ExtractSliceOp and the
  // shape for EmptyOp.
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

  // 2. Transpose the tile to match the outer corresponding tile order.
  SmallVector<int64_t> perm =
      getPackUnpackRankReducedPerm(srcShape.take_front(destRank), innerDimsPos,
                                   unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space so we invert the permutation.
  perm = invertPermutationVector(perm);
  applyPermutationToVector(shapeForEmptyOp, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles if needed. This truncates trailing data from
  // the transposed tile.
  int numLoops = shapeForEmptyOp.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result to the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp,
                             PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on conv2D layout.
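  // E.g. for Conv2DNhwcHwcfOp the kernel layout is HWCF, so kh/kw live at
  // kernel indices 0/1, and the output layout is NHWC, so oh/ow live at
  // output indices 1/2 - matching the (0, 1, 1, 2) tuple below.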
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
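  // The rank-reduced 1-D convolution result is re-inserted into the original
  // 2-D output tensor, restoring the original rank.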
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}

template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                              Conv1DNwcWcfOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                              Conv1DNcwFcwOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
                                                              PoolingNwcSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
                                                              PoolingNcwSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
                                                              PoolingNwcMaxOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
                                                              PoolingNwcMinOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
                                                              PoolingNcwMaxOp>;

FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}

FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.
  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                     Conv1DNwcWcfOp>,
               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                     Conv1DNcwFcwOp>,
               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
      patterns.getContext(), benefit);
  patterns.add<
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
      patterns.getContext(), benefit);
}

void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
  patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
}

void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
  patterns.add<DecomposePadOpPattern>(patterns.getContext());
}