//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a (smaller) reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

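/// A rough sketch of the rewrite performed by `splitReduction` (shapes and
/// exact attribute spelling are illustrative, not verbatim output). With
/// ratio = 4 and index = 0, a 1-D sum reduction
///
///   %r = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
///          iterator_types = ["reduction"]}
///        ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) { arith.addf ... }
///
/// is rewritten into a partial reduction over an expanded input followed by a
/// final reduction of the extra parallel dimension:
///
///   %e = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
///   %f = linalg.fill ins(%cst : f32) outs(%empty : tensor<4xf32>)
///   %p = linalg.generic {iterator_types = ["parallel", "reduction"]}
///        ins(%e : tensor<4x8xf32>) outs(%f : tensor<4xf32>) { arith.addf ... }
///   %r = linalg.generic {iterator_types = ["reduction"]}
///        ins(%p : tensor<4xf32>) outs(%out : tensor<f32>) { arith.addf ... }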
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio); // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio); // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted
  // based on the index returned by `controlSplitReductionFn`.
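  // E.g., for an (M, N) output with insertSplitIndex == 1, the intermediate
  // output shape becomes (M, ratio, N).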
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
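/// E.g. (a sketch of the intended behavior): with reductionDimPos == 1 and
/// reductionRatio == 4, scaleReductionDim rewrites an indexing map
/// (d0, d1) -> (d1, d0) into (d0, d1, d2) -> (d1 * 4 + d2, d0).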
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

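/// Insert a result for the new parallel dimension at position
/// `reductionDimPos` in the indexing map of `opOperand`, shifting the
/// following loop dimensions by one. E.g. (a sketch of the intended behavior):
/// with reductionDimPos == 1, an output indexing map (d0, d1) -> (d0) becomes
/// (d0, d1, d2) -> (d0, d1).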
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
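/// A rough sketch of the rewritten IR (assuming a single reduced output; the
/// exact ops depend on `useAlloc` and the control function):
///   1. a tensor.empty/bufferization.alloc_tensor + linalg.fill of the
///      reduction's neutral element, with an extra dimension of size
///      `reductionDimSize / splitFactor` inserted at `insertSplitDimension`;
///   2. a linalg.generic computing partial reductions, in which the original
///      reduction index is rewritten as `k * splitFactor + kk`, with `k` a new
///      parallel dimension and `kk` the remaining reduction dimension;
///   3. one linalg.generic per output reducing the extra dimension into the
///      original output.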
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(k, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(k, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting to
  // more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp; the provided
  /// `controlSplitReductionFn` decides whether and how each op is split.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

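// A minimal usage sketch (illustrative, not part of this file): populate the
// pattern with a control function that picks a fixed split.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(patterns, [](LinalgOp op) {
//     return SplitReductionOptions{/*ratio=*/8, /*index=*/0,
//                                  /*innerParallel=*/false};
//   });
//   // Then apply `patterns` with the usual greedy pattern rewrite driver.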
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}