//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//
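//
// As a rough, hypothetical illustration (not taken from the tests; names,
// shapes and the chosen combiner are made up), splitReduction with
// { ratio = 4, index = 0 } rewrites a 1-D sum such as
//
//   %r = linalg.generic {
//          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
//          iterator_types = ["reduction"]}
//        ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) {
//        ^bb0(%a: f32, %b: f32):
//          %s = arith.addf %a, %b : f32
//          linalg.yield %s : f32
//        } -> tensor<f32>
//
// into a partly parallel partial reduction followed by a final reduction of
// the newly inserted dimension:
//
//   %e = tensor.expand_shape %in [[0, 1]]
//          : tensor<32xf32> into tensor<4x8xf32>
//   %cst = arith.constant 0.0 : f32   // neutral element of arith.addf
//   %empty = tensor.empty() : tensor<4xf32>
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%empty : tensor<4xf32>) -> tensor<4xf32>
//   %p = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                           affine_map<(d0, d1) -> (d0)>],
//          iterator_types = ["parallel", "reduction"]}
//        ins(%e : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) { ... }
//        -> tensor<4xf32>
//   %r = linalg.generic {
//          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
//          iterator_types = ["reduction"]}
//        ins(%p : tensor<4xf32>) outs(%out : tensor<f32>) { ... }
//        -> tensor<f32>
//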

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio);                             // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio);                             // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted
  // based on the index returned by `controlSplitReductionFn`.
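  // For example (hypothetical shapes): with ratio = 4 and index = 0, an
  // original output of type tensor<16x32xf32> yields an intermediate output
  // of type tensor<4x16x32xf32>.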
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
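/// For illustration (hypothetical values): with reductionDimPos = 1 and
/// reductionRatio = 4, the indexing map (d0, d1) -> (d0, d1) becomes
/// (d0, d1, d2) -> (d0, d1 * 4 + d2).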
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

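/// Insert, at result position `reductionDimPos`, a result referencing the
/// extra parallel dimension into the indexing map of `opOperand`, shifting the
/// loop dimensions above `reductionDimPos` up by one.
/// For illustration (hypothetical values): with reductionDimPos = 2, the map
/// (d0, d1, d2) -> (d0, d1) becomes (d0, d1, d2, d3) -> (d0, d1, d2).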
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting to
  // more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
      loc,
      originalOutputType,
      reindexedOutput,
      originalOutput,
      indexingMaps,
      reductionIteratorTypes,
      [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
        Operation *clonedReductionOp = b.clone(*combinerOp);
        clonedReductionOp->setOperand(0, bbArgs[0]);
        clonedReductionOp->setOperand(1, bbArgs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a pattern applied to all LinalgOp; the split is driven by
  /// `controlSplitReductionFn`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}