xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp (revision 97069a86193a617a9e4cf742a29db6116b2bf449)
//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
233d2a780SThomas Raoux //
333d2a780SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
433d2a780SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
533d2a780SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
633d2a780SThomas Raoux //
733d2a780SThomas Raoux //===----------------------------------------------------------------------===//
833d2a780SThomas Raoux //
933d2a780SThomas Raoux // This file implements linalg transformation to break a reduction dimension
1033d2a780SThomas Raoux // between a parallel and a reduction dimension.
1133d2a780SThomas Raoux //
1233d2a780SThomas Raoux //===----------------------------------------------------------------------===//
1333d2a780SThomas Raoux 
14a1fe1f5fSKazu Hirata #include <optional>
158ce6f7ddSNicolas Vasilache #include <utility>
16e188ad8bSMehdi Amini 
1733d2a780SThomas Raoux #include "mlir/Analysis/SliceAnalysis.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
19178f9bd6SNicolas Vasilache #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2033d2a780SThomas Raoux #include "mlir/Dialect/Linalg/IR/Linalg.h"
2133d2a780SThomas Raoux #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
2233d2a780SThomas Raoux #include "mlir/Dialect/Linalg/Utils/Utils.h"
2333d2a780SThomas Raoux #include "mlir/Dialect/Tensor/IR/Tensor.h"
24d5716395SNicolas Vasilache #include "mlir/Dialect/Tensor/Utils/Utils.h"
2533d2a780SThomas Raoux #include "mlir/IR/PatternMatch.h"
2633d2a780SThomas Raoux 
2733d2a780SThomas Raoux using namespace mlir;
2833d2a780SThomas Raoux using namespace mlir::linalg;
2933d2a780SThomas Raoux 
/// Split the (single) reduction dimension of `op` into an outer reduction and
/// an inner parallel dimension of size `control.ratio`, producing:
///   1. a GenericOp that performs the partial reductions into an intermediate
///      tensor with one extra parallel dimension, and
///   2. a follow-up GenericOp that reduces that extra dimension into the
///      original output.
/// `control.index` selects both where the extra dimension is inserted in the
/// intermediate tensor (`insertSplitIndex`) and, unless `control.innerParallel`
/// is set, the position of the new loop dimension (`insertSplitDimension`).
/// With `innerParallel`, the new loop dimension is placed right after the
/// reduction loop instead. `useAlloc` selects bufferization.alloc_tensor over
/// tensor.empty for the intermediate. Returns the created ops on success, or a
/// match failure (with diagnostic) when preconditions do not hold.
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  // All new IR is created immediately before `op`; the guard restores the
  // caller's insertion point on exit.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  // Only the single-reduction case is handled here.
  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    // The new parallel loop goes immediately after the reduction loop,
    // ignoring control.index for the loop position.
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  // The reduction extent must be static and evenly divisible by the ratio so
  // the split produces no remainder iterations.
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

  // The output block arguments must feed a single combiner op (add, mul, ...)
  // for the reduction to be recognizable and re-usable below.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  // Neutral element (e.g. 0 for add, 1 for mul) used to fill the intermediate.
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The operand dimension mapped to the reduction loop is split in two;
        // the relative order of the (reduce, parallel) pair depends on
        // innerParallel.
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio); // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio); // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        // Both new tensor dimensions come from the single original one.
        reassociation.push_back({index++, index++});
        continue;
      }
      // Non-reduction dims keep their shape; their loop index shifts by one if
      // it lies at or past the inserted loop dimension.
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    // Materialize the split via tensor.expand_shape on the input.
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      // Inserted dimension of extent `ratio`, indexed by the new loop dim.
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  // Intermediate tensor holding the partial reductions.
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  // Fill the intermediate with the reduction's neutral element.
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  // Iterator types: the original ones with one extra `parallel` spliced in at
  // `insertSplitDimension` (or appended if it is past the end).
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  // Reuse the original op's body wholesale; indexing is handled by the maps.
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduce the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      // Only the inserted dimension is reduced; it is dropped from the output
      // map by simply not emitting an expression for it.
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  // Final reduction combining the partial results with a clone of the
  // original combiner op.
  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}
21033d2a780SThomas Raoux 
211d5716395SNicolas Vasilache /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
212d5716395SNicolas Vasilache /// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
213d5716395SNicolas Vasilache /// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
214d5716395SNicolas Vasilache /// done as a transform to enable better vectorization.
scaleReductionDim(LinalgOp op,OpOperand & opOperand,unsigned reductionDimPos,int64_t reductionRatio)215d5716395SNicolas Vasilache static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
216d5716395SNicolas Vasilache                                    unsigned reductionDimPos,
217d5716395SNicolas Vasilache                                    int64_t reductionRatio) {
218d5716395SNicolas Vasilache   auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
219d5716395SNicolas Vasilache   auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
2201227b8abSOleg Shyshkov   AffineMap map = op.getMatchingIndexingMap(&opOperand);
221d5716395SNicolas Vasilache   AffineMap idMap =
222d5716395SNicolas Vasilache       AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
223d5716395SNicolas Vasilache   AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
224d5716395SNicolas Vasilache   AffineMap composeMap = shiftedIdMap.replace(
225d5716395SNicolas Vasilache       reductionDim, reductionDim * reductionRatio + reductionDimP1,
226d5716395SNicolas Vasilache       shiftedIdMap.getNumDims(), /*numSymbols=*/0);
227d5716395SNicolas Vasilache   return map.compose(composeMap);
228d5716395SNicolas Vasilache }
229d5716395SNicolas Vasilache 
insertParallelDim(LinalgOp op,OpOperand & opOperand,unsigned reductionDimPos,int64_t size)230d5716395SNicolas Vasilache static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
231d5716395SNicolas Vasilache                                    unsigned reductionDimPos, int64_t size) {
232d5716395SNicolas Vasilache   auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
2331227b8abSOleg Shyshkov   AffineMap map = op.getMatchingIndexingMap(&opOperand);
234d5716395SNicolas Vasilache   AffineMap idMap =
235d5716395SNicolas Vasilache       AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
236d5716395SNicolas Vasilache   AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
237d5716395SNicolas Vasilache   return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
238d5716395SNicolas Vasilache }
239d5716395SNicolas Vasilache 
240d5716395SNicolas Vasilache /// Core rewrite implementation.
/// Core rewrite implementation.
/// Split the first reduction dimension of `op` by scaling its index
/// (k -> k * splitFactor + kk) rather than reshaping inputs: the original body
/// is reused with rewritten indexing maps, accumulating partial reductions
/// into intermediates that gain one extra dimension at
/// `control.index` (`insertSplitDimension`), followed by one final reducing
/// GenericOp per output. `useAlloc` selects bufferization.alloc_tensor over
/// tensor.empty for the intermediates. Returns the created ops on success or
/// a match failure with diagnostic.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  // All new IR goes immediately before `op`; guard restores insertion point.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  // Only the first reduction dimension is split.
  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  // Require a static extent evenly divisible by the factor and a valid
  // insertion position within the loop nest.
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  // One neutral element (e.g. 0 for add) per combiner, used to seed the
  // intermediate accumulators.
  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    // Same shape as the original init, plus the split dimension of extent
    // reductionDimSize / splitFactor at `insertSplitDimension`.
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    // Seed the accumulator with the combiner's neutral element.
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting to
  // more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  // Block argument for the shape-only i1 tensor; its value is never used.
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    // Identity input map over the expanded rank; the output map drops the
    // split dimension, which is the one being reduced.
    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          // Reduce with a clone of the original combiner.
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}
424d5716395SNicolas Vasilache 
42533d2a780SThomas Raoux namespace {
42633d2a780SThomas Raoux 
42733d2a780SThomas Raoux struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
42833d2a780SThomas Raoux   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgSplitReduction__anonaac8c2be0411::LinalgSplitReduction42933d2a780SThomas Raoux   LinalgSplitReduction(MLIRContext *context,
43033d2a780SThomas Raoux                        ControlSplitReductionFn controlSplitReductionFn,
431e0cea169SNicolas Vasilache                        bool useAlloc = false, PatternBenefit benefit = 1)
43233d2a780SThomas Raoux       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
433e188ad8bSMehdi Amini         controlSplitReductionFn(std::move(controlSplitReductionFn)),
434e0cea169SNicolas Vasilache         useAlloc(useAlloc) {}
43533d2a780SThomas Raoux 
matchAndRewrite__anonaac8c2be0411::LinalgSplitReduction43633d2a780SThomas Raoux   LogicalResult matchAndRewrite(LinalgOp op,
43733d2a780SThomas Raoux                                 PatternRewriter &rewriter) const override {
438e0cea169SNicolas Vasilache     return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
43933d2a780SThomas Raoux   }
44033d2a780SThomas Raoux 
44133d2a780SThomas Raoux private:
44233d2a780SThomas Raoux   ControlSplitReductionFn controlSplitReductionFn;
443178f9bd6SNicolas Vasilache   bool useAlloc;
44433d2a780SThomas Raoux };
44533d2a780SThomas Raoux 
44633d2a780SThomas Raoux } // namespace
44733d2a780SThomas Raoux 
populateSplitReductionPattern(RewritePatternSet & patterns,const ControlSplitReductionFn & controlSplitReductionFn,bool useAlloc)44833d2a780SThomas Raoux void linalg::populateSplitReductionPattern(
44933d2a780SThomas Raoux     RewritePatternSet &patterns,
450e0cea169SNicolas Vasilache     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
45133d2a780SThomas Raoux   patterns.add<LinalgSplitReduction>(patterns.getContext(),
452e0cea169SNicolas Vasilache                                      controlSplitReductionFn, useAlloc);
45333d2a780SThomas Raoux }
454