//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

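// An illustrative sketch of the rewrite performed below (example values only:
// a 1-D f32 sum reduction with control.ratio = 4 and control.index = 0).
// Roughly, a reduction of tensor<32xf32> into tensor<f32>:
//
//   %r = linalg.generic
//          {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
//           iterator_types = ["reduction"]}
//          ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) { addf }
//
// becomes a partial reduction over an expanded operand followed by a final
// reduction of the newly inserted parallel dimension:
//
//   %exp = tensor.expand_shape %in [[0, 1]]
//            : tensor<32xf32> into tensor<4x8xf32>
//   %fill = linalg.fill ins(%identity) outs(%empty : tensor<4xf32>)
//   %part = linalg.generic
//             {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                               affine_map<(d0, d1) -> (d0)>],
//              iterator_types = ["parallel", "reduction"]}
//             ins(%exp : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) { addf }
//   %r    = linalg.generic
//             {indexing_maps = [affine_map<(d0) -> (d0)>,
//                               affine_map<(d0) -> ()>],
//              iterator_types = ["reduction"]}
//             ins(%part : tensor<4xf32>) outs(%out : tensor<f32>) { addf }
//
// The `{ addf }` bodies above are shorthand for the usual arith.addf +
// linalg.yield region.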
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
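  // A rough illustration (values assumed, not taken from the op at hand): with
  // ratio = 4, control.index = 1 and innerParallel = false, an input accessed
  // as (d0, d1) -> (d1) over tensor<32xf32> is expanded to tensor<4x8xf32> and
  // re-indexed as (d0, d1, d2) -> (d1, d2), where d1 is the inserted parallel
  // dimension and d2 the remaining (smaller) reduction dimension.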
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio);                             // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio);                             // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
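  // E.g. (illustrative values only): for an output of shape 3x5 with ratio = 4,
  // control.index = 0 gives an intermediate output of shape 4x3x5, while
  // control.index = 2 gives 3x5x4.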
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
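  // For example (illustrative values only): with intermRank == 2 and
  // insertSplitIndex == 0 this reduces d0, i.e. ins: (d0, d1) -> (d0, d1),
  // outs: (d0, d1) -> (d1), iterator_types = ["reduction", "parallel"].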

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
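///
/// A sketch of the effect on an indexing map (example values only): with
/// reductionDimPos == 2 and reductionRatio == 16, an operand map
/// (d0, d1, d2) -> (d0, d2) becomes (d0, d1, d2, d3) -> (d0, d2 * 16 + d3),
/// where d3 is the newly introduced dimension.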
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

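/// Re-express the indexing map of `opOperand` over `op.getNumLoops() + 1`
/// dimensions, making room for the dimension added by the split, and insert an
/// access to dimension `reductionDimPos` at result position `reductionDimPos`.
/// Note that `size` is not used by the current implementation.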
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra
  // dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //     a. O_i(kk, i, j) += I(i, j, 16 * kk + k)
  //     b. O(i, j) += O_i(kk, i, j)
  // with kk ranging over 128/16 == 8 and k over 16, so the intermediate tensor
  // O_i has shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
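  // This map matches the (reductionDimSize / splitFactor) x splitFactor
  // shape-only tensor added in Step 3 below, so that both new loop dimensions
  // are carried by an operand even though no input is reshaped.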
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;
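    // For example (illustrative values only): for a rank-2 original output and
    // insertSplitDimension == 0 this yields ins: (d0, d1, d2) -> (d0, d1, d2),
    // outs: (d0, d1, d2) -> (d1, d2), reducing only the leading split dim.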

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
      loc,
      originalOutputType,
      reindexedOutput,
      originalOutput,
      indexingMaps,
      reductionIteratorTypes,
      [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
        Operation *clonedReductionOp = b.clone(*combinerOp);
        clonedReductionOp->setOperand(0, bbArgs[0]);
        clonedReductionOp->setOperand(1, bbArgs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps; the actual split is
  /// driven by `controlSplitReductionFn`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

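/// A minimal usage sketch (hypothetical caller code, values assumed): register
/// the pattern with a control function that splits reductions by a factor of 8,
/// inserting the extra dimension at index 0. Ops for which the returned ratio
/// is <= 1 are left untouched by the pattern above.
///
///   RewritePatternSet patterns(ctx);
///   linalg::populateSplitReductionPattern(
///       patterns,
///       [](LinalgOp op) {
///         return SplitReductionOptions{/*ratio=*/8, /*index=*/0,
///                                      /*innerParallel=*/false};
///       },
///       /*useAlloc=*/false);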
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}