xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Split.cpp (revision 63f5b80fcd94ca30a29677ad9431c4f743b61d74)
1 //===- Split.cpp - Structured op splitting --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/Utils/StaticValueUtils.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Interfaces/TilingInterface.h"
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 /// Creates a part of the given `op` split along the iteration space `dimension`
25 /// with the given `size` and an optional `offset` (default 0). Makes slices
26 /// of operands, using the input operands of the original op and the output
27 /// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
28 /// define the shape of the iteration space of the original op. Returns the
29 /// split-out op as well as the output operand values updated with the partial
30 /// results produced by this op through `results`.
31 static TilingInterface
32 createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
33                 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
34                 ValueRange resultOperands, unsigned dimension,
35                 OpFoldResult size, OpFoldResult offset,
36                 SmallVectorImpl<Value> &results) {
37   // Iteration space of the current part.
38   SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
39   SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
40   sizesCopy[dimension] = size;
41   offsetsCopy[dimension] = offset;
42 
43   // Create the part as if it were a single tile.
44   FailureOr<TilingResult> tilingResult =
45       op.getTiledImplementation(b, offsetsCopy, sizesCopy);
46 
47   // Insert the results back and populate the `results` list.
48   for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
49     SmallVector<OpFoldResult> resultOffsets, resultSizes;
50     if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
51                                         resultOffsets, resultSizes)))
52       return nullptr;
53     SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
54                                             b.getIndexAttr(1));
55     Value inserted = b.create<tensor::InsertSliceOp>(
56         loc, result, resultOperands[index], resultOffsets, resultSizes,
57         resultStrides);
58     results.push_back(inserted);
59   }
60   // TODO: this part can be generalized maybe to not expect a single op.
61   assert(tilingResult->tiledOps.size() == 1 &&
62          "expected split part to return a single tiled operation");
63   return cast<TilingInterface>(tilingResult->tiledOps[0]);
64 }
65 
66 std::pair<TilingInterface, TilingInterface>
67 linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
68                 OpFoldResult splitPoint) {
69   // Compute the iteration space.
70   SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
71 
72   // Bail out on dimension overflow.
73   if (dimension >= iterationSpace.size())
74     return std::make_pair(op, TilingInterface());
75 
76   SmallVector<OpFoldResult> offsets = llvm::to_vector(llvm::map_range(
77       iterationSpace, [](const Range &range) { return range.offset; }));
78   SmallVector<OpFoldResult> sizes = llvm::to_vector(llvm::map_range(
79       iterationSpace, [](const Range &range) { return range.size; }));
80 
81   // Adjust the split point so that it doesn't overflow the size.
82   AffineExpr d0, d1, d2;
83   bindDims(rewriter.getContext(), d0, d1, d2);
84   OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin(
85       rewriter, op.getLoc(),
86       AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2},
87                                    rewriter.getContext())
88           .front(),
89       {splitPoint, offsets[dimension], sizes[dimension]});
90 
91   // Compute the size of the second part. Return early if the second part would
92   // have an empty iteration space.
93   OpFoldResult remainingSize = affine::makeComposedFoldedAffineApply(
94       rewriter, op.getLoc(), d0 + d1 - d2,
95       {iterationSpace[dimension].offset, iterationSpace[dimension].size,
96        minSplitPoint});
97   if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
98     if (cast<IntegerAttr>(attr).getValue().isZero())
99       return {op, TilingInterface()};
100   }
101 
102   // Compute destination tensors.
103   SmallVector<Value> destinationTensors;
104   LogicalResult destStatus = tensor::getOrCreateDestinations(
105       rewriter, op.getLoc(), op, destinationTensors);
106   (void)destStatus;
107   assert(succeeded(destStatus) && "failed to get destination tensors");
108 
109   // Create the first part.
110   SmallVector<Value> firstResults;
111   TilingInterface firstPart = createSplitPart(
112       rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
113       minSplitPoint, iterationSpace[dimension].offset, firstResults);
114 
115   // Need to pretend that the original op now takes as operands firstResults,
116   // otherwise tiling interface implementation will take the wrong value to
117   // produce data tiles.
118   rewriter.modifyOpInPlace(op, [&]() {
119     unsigned numTotalOperands = op->getNumOperands();
120     unsigned numOutputOperands = firstResults.size();
121     op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
122                     firstResults);
123   });
124 
125   // Create the second part.
126   OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply(
127       rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
128   SmallVector<Value> secondResults;
129   TilingInterface secondPart =
130       createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
131                       dimension, remainingSize, totalOffset, secondResults);
132 
133   // Propagate any errors in part creation.
134   if (!firstPart || !secondPart)
135     return {TilingInterface(), TilingInterface()};
136 
137   // Replace the original op with the results of the two newly created ops.
138   rewriter.replaceOp(op, secondResults);
139   return {firstPart, secondPart};
140 }
141