1 //===- Split.cpp - Structured op splitting --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/Utils/StaticValueUtils.h" 12 #include "mlir/IR/AffineExpr.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/BuiltinAttributes.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/Interfaces/TilingInterface.h" 17 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/SmallVector.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 /// Creates a part of the given `op` split along the iteration space `dimension` 25 /// with the given `size` and an optional `offset` (default 0). Makes slices 26 /// of operands, using the input operands of the original op and the output 27 /// operands provided as `resultOperands`. Expects `offsets` and `sizes` to 28 /// define the shape of the iteration space of the original op. Returns the 29 /// split-out op as well as the output operand values updated with the partial 30 /// results produced by this op through `results`. 31 static TilingInterface 32 createSplitPart(RewriterBase &b, Location loc, TilingInterface op, 33 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 34 ValueRange resultOperands, unsigned dimension, 35 OpFoldResult size, OpFoldResult offset, 36 SmallVectorImpl<Value> &results) { 37 // Iteration space of the current part. 38 SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes); 39 SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets); 40 sizesCopy[dimension] = size; 41 offsetsCopy[dimension] = offset; 42 43 // Create the part as if it were a single tile. 44 FailureOr<TilingResult> tilingResult = 45 op.getTiledImplementation(b, offsetsCopy, sizesCopy); 46 47 // Insert the results back and populate the `results` list. 48 for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { 49 SmallVector<OpFoldResult> resultOffsets, resultSizes; 50 if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, 51 resultOffsets, resultSizes))) 52 return nullptr; 53 SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), 54 b.getIndexAttr(1)); 55 Value inserted = b.create<tensor::InsertSliceOp>( 56 loc, result, resultOperands[index], resultOffsets, resultSizes, 57 resultStrides); 58 results.push_back(inserted); 59 } 60 // TODO: this part can be generalized maybe to not expect a single op. 61 assert(tilingResult->tiledOps.size() == 1 && 62 "expected split part to return a single tiled operation"); 63 return cast<TilingInterface>(tilingResult->tiledOps[0]); 64 } 65 66 std::pair<TilingInterface, TilingInterface> 67 linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, 68 OpFoldResult splitPoint) { 69 // Compute the iteration space. 70 SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter); 71 72 // Bail out on dimension overflow. 73 if (dimension >= iterationSpace.size()) 74 return std::make_pair(op, TilingInterface()); 75 76 SmallVector<OpFoldResult> offsets = llvm::to_vector(llvm::map_range( 77 iterationSpace, [](const Range &range) { return range.offset; })); 78 SmallVector<OpFoldResult> sizes = llvm::to_vector(llvm::map_range( 79 iterationSpace, [](const Range &range) { return range.size; })); 80 81 // Adjust the split point so that it doesn't overflow the size. 82 AffineExpr d0, d1, d2; 83 bindDims(rewriter.getContext(), d0, d1, d2); 84 OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin( 85 rewriter, op.getLoc(), 86 AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2}, 87 rewriter.getContext()) 88 .front(), 89 {splitPoint, offsets[dimension], sizes[dimension]}); 90 91 // Compute the size of the second part. Return early if the second part would 92 // have an empty iteration space. 93 OpFoldResult remainingSize = affine::makeComposedFoldedAffineApply( 94 rewriter, op.getLoc(), d0 + d1 - d2, 95 {iterationSpace[dimension].offset, iterationSpace[dimension].size, 96 minSplitPoint}); 97 if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) { 98 if (cast<IntegerAttr>(attr).getValue().isZero()) 99 return {op, TilingInterface()}; 100 } 101 102 // Compute destination tensors. 103 SmallVector<Value> destinationTensors; 104 LogicalResult destStatus = tensor::getOrCreateDestinations( 105 rewriter, op.getLoc(), op, destinationTensors); 106 (void)destStatus; 107 assert(succeeded(destStatus) && "failed to get destination tensors"); 108 109 // Create the first part. 110 SmallVector<Value> firstResults; 111 TilingInterface firstPart = createSplitPart( 112 rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension, 113 minSplitPoint, iterationSpace[dimension].offset, firstResults); 114 115 // Need to pretend that the original op now takes as operands firstResults, 116 // otherwise tiling interface implementation will take the wrong value to 117 // produce data tiles. 118 rewriter.modifyOpInPlace(op, [&]() { 119 unsigned numTotalOperands = op->getNumOperands(); 120 unsigned numOutputOperands = firstResults.size(); 121 op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands, 122 firstResults); 123 }); 124 125 // Create the second part. 126 OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply( 127 rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint}); 128 SmallVector<Value> secondResults; 129 TilingInterface secondPart = 130 createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults, 131 dimension, remainingSize, totalOffset, secondResults); 132 133 // Propagate any errors in part creation. 134 if (!firstPart || !secondPart) 135 return {TilingInterface(), TilingInterface()}; 136 137 // Replace the original op with the results of the two newly created ops. 138 rewriter.replaceOp(op, secondResults); 139 return {firstPart, secondPart}; 140 } 141