//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the blockArgument corresponding to the init is used in the region. This is
/// a fix-up when unit reduction dimensions are all folded away. In this
/// context, it becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }
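    // Recreate the output block arguments: inits that were moved to the
    // inputs already had their original block argument remapped to a new
    // input argument above, so only a (dead) output argument is appended for
    // them; the remaining inits map to their new output arguments.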
    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());
    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///   } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///   } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
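/// Indices of dropped unit-trip loops are replaced by a constant zero and the
/// remaining `linalg.index` ops are renumbered. For example, if loops 0 and 2
/// of a four-loop op are dropped, `linalg.index 0` and `linalg.index 2`
/// become constant 0, while `linalg.index 1` and `linalg.index 3` become
/// `linalg.index 0` and `linalg.index 1`, respectively.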
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

/// Expand the given `result` so that the type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return rewriter
      .create<tensor::ExpandShapeOp>(loc, origResultType, result,
                                     reassociation)
      .getResult();
}

/// Collapse the given `operand` so that its type matches the given
/// `targetShape`. The `reassociation` is used when `rankReductionStrategy` is
/// set to `RankReductionStrategy::ReassociativeReshape`.
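/// For example (illustrative), collapsing a `tensor<1x4x1xf32>` operand to
/// the target shape `[4]` yields either a `tensor.collapse_shape` with
/// reassociation `[[0, 1, 2]]` or a rank-reducing `tensor.extract_slice`,
/// depending on the rank-reduction strategy.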
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(rankReductionStrategy ==
               ControlDropUnitDims::RankReductionStrategy::
                   ReassociativeReshape &&
           "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(rankReductionStrategy ==
               ControlDropUnitDims::RankReductionStrategy::
                   ReassociativeReshape &&
           "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  llvm_unreachable("unsupported operand type");
}

/// Compute the modified metadata for an operand of an operation whose unit
/// dims are being dropped. Return the new indexing map to use, the shape of
/// the operand in the replacement op, and the `reassociation` to use to go
/// from the original operand shape to the modified operand shape.
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};

static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  //    end up being unit-trip count if they are used to index into a unit-dim
  //    tensor/memref.
  AffineMap invertedMap =
      inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "invalid indexing maps for operation");
  }
  SmallVector<int64_t> dims = genericOp.getStaticShape();

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        genericOp, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (dims[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the iterator types of the modified op by dropping the one-trip
  //    count loops.
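  //    For example, if a four-loop op has unit dims {0, 2}, the modified op
  //    has two loops, `oldDimToNewDimMap` is {1 -> 0, 3 -> 1}, and
  //    `dimReplacements` is [0, d0, 0, d1].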
  SmallVector<utils::IteratorType> newIteratorTypes;
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto [index, attr] :
       llvm::enumerate(genericOp.getIteratorTypesArray())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      newIteratorTypes.push_back(attr);
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each of the operands, find the
  //    - modified affine map to use.
  //    - shape of the operands after the unit-dims are dropped.
  //    - the reassociation indices used to convert from the original
  //      operand type to the modified operand type (needed only when using
  //      reshapes for the rank reduction strategy).
  //    Note that the indexing maps might need changing even if there are no
  //    unit dimensions that are dropped to handle cases where `0` is used to
  //    access a unit-extent tensor. Consider moving this out of this specific
  //    transformation as a stand-alone transformation. Kept here right now due
  //    to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    }
    if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = genericOp.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  //    from the original shape to the shape in the modified operation if
  //    needed, either through use of reshapes or rank-reducing slices as
  //    specified in `options`.
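  //    For example, a `tensor<1x5xf32>` operand whose unit dimension is
  //    dropped is collapsed to `tensor<5xf32>` here, either with a
  //    `tensor.collapse_shape` or with a rank-reducing `tensor.extract_slice`.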
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  // 5. Create the `linalg.generic` operation with the new operands,
  //    indexing maps, iterator types and result types.
  ArrayRef<Value> newInputs =
      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
  ArrayRef<Value> newOutputs =
      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
  SmallVector<Type> resultTypes;
  resultTypes.reserve(genericOp.getNumResults());
  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
    resultTypes.push_back(newOutputs[i].getType());
  GenericOp replacementOp =
      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
                                 newIndexingMaps, newIteratorTypes);
  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                              replacementOp.getRegion().begin());

  // 5a. Replace `linalg.index` operations that refer to the dropped unit
  //     dimensions.
  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  //    original type to the new type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
    Value origDest = genericOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    Value expandedValue =
        expandValue(rewriter, loc, result, origDest,
                    reassociations[opOperandIndex],
                    options.rankReductionStrategy);
    resultReplacements.push_back(expandedValue);
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}

namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result)) {
      return failure();
    }
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
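/// Fold unit-extent source dimensions of `tensor.pad` ops that are not
/// actually padded: the pad is applied to a rank-reduced source and the
/// result is expanded back to the original type. For example (illustrative):
///
///   %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 3, 0] { ... }
///       : tensor<1x?x1xf32> to tensor<1x?x1xf32>
///
/// becomes a pad of the collapsed `tensor<?xf32>` with low[2]/high[3], whose
/// result is expanded back to `tensor<1x?x1xf32>`.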
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      std::optional<int64_t> maybeInt = getConstantIntValue(f);
      return maybeInt && *maybeInt == 0;
    };

    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, low, high] :
         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
    }

    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    Value collapsedSource =
        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                      reassociationMap, options.rankReductionStrategy);

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
        newHighPad, paddingVal, padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = rewriter.create<tensor::EmptyOp>(
          padOp.getLoc(), expandedSizes,
          padOp.getResultType().getElementType());
    }

    Value expandedValue =
        expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                    reassociationMap, options.rankReductionStrategy);
    rewriter.replaceOp(padOp, expandedValue);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
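/// For example (illustrative):
///
///   %0 = tensor.extract_slice %arg0[0, 0] [1, 5] [1, 1]
///       : tensor<10x10xf32> to tensor<1x5xf32>
///
/// is rewritten to extract a rank-reduced `tensor<5xf32>` that is then
/// expanded back to `tensor<1x5xf32>` with a `tensor.expand_shape`.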
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy =
          linalg::ControlDropUnitDims::RankReductionStrategy::
              ExtractInsertSlice;
    }
    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
    populateMoveInitOperandsToInputPattern(patterns);
    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};
} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
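/// For example, for `rank = 4` this returns `[[0], [1, 2], [3]]` when
/// `pos = 1` and `[[0], [1], [2, 3]]` when `pos = 3`.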
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];

    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
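  /// For example, collapsing the unit M dimension of a `linalg.matmul` uses
  /// {0, -1, 0}: dim 0 of the LHS, nothing on the RHS, and dim 0 of the init.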
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Patterns for reducing non-batch dimensions.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};

} // namespace
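// These patterns drop unit-extent dimensions from named contraction ops. For
// example (illustrative), a `linalg.batch_matmul` with unit batch size becomes
// a `linalg.matmul`, and a `linalg.matmul` whose M or N dimension is unit
// becomes a `linalg.vecmat` or `linalg.matvec`, respectively.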
void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
          context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
          context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank 1 reducing patterns
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
  // Batch rank 1 reducing patterns
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
      context);
  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
      context);

  // Non-batch rank 0 reducing patterns
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}