xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (revision 9cbc1f29cabc01c02a523c11d098c00650f6955c)
1 //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 #include <optional>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
23 #include "mlir/Dialect/Linalg/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 /// Return constant range span or nullopt, otherwise.
30 static std::optional<int64_t> getConstantRange(const Range &range) {
31   std::optional<int64_t> stride = getConstantIntValue(range.stride);
32   if (!stride || *stride != 1)
33     return std::nullopt;
34   std::optional<int64_t> offset = getConstantIntValue(range.offset);
35   if (!offset)
36     return std::nullopt;
37   std::optional<int64_t> size = getConstantIntValue(range.size);
38   if (!size)
39     return std::nullopt;
40   return (*size - *offset);
41 }
42 
43 /// Return true if all dimensions are fully divisible by the respective tiles.
44 static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
45                                     ArrayRef<OpFoldResult> tiles,
46                                     ArrayRef<int64_t> dims) {
47   if (dims.size() != tiles.size() || tiles.empty())
48     return false;
49 
50   FailureOr<ContractionDimensions> contractDims =
51       inferContractionDims(linalgOp);
52   if (failed(contractDims))
53     return false;
54   unsigned batchDimsOffset = contractDims->batch.size();
55 
56   // Skip the batch dimension if present.
57   // Offset all dimensions accordingly.
58   SmallVector<int64_t, 3> offsetDims(dims);
59   for (size_t i = 0; i < offsetDims.size(); i++)
60     offsetDims[i] += batchDimsOffset;
61 
62   auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
63   OpBuilder builder(tileOp);
64   OpBuilder::InsertionGuard guard(builder);
65   SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
66 
67   for (auto dim : llvm::enumerate(offsetDims)) {
68     if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
69       return false;
70 
71     std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
72     std::optional<int64_t> rangeOnDim =
73         getConstantRange(iterationDomain[dim.value()]);
74 
75     // If the tile factor or the range are non-constant, the tile size is
76     // considered to be invalid.
77     if (!tileSize || !rangeOnDim)
78       return false;
79 
80     // The dimension must be fully divisible by the tile.
81     if (*rangeOnDim % *tileSize != 0)
82       return false;
83   }
84 
85   return true;
86 }
87 
88 /// Return failure or packed matmul with one of its operands transposed.
89 static FailureOr<PackTransposeResult>
90 transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
91                       tensor::PackOp packOp, AffineMap operandMap,
92                       ArrayRef<unsigned> blocksStartDimPos,
93                       bool transposeOuterBlocks, bool transposeInnerBlocks) {
94   assert(operandMap.getNumDims() >= 4 &&
95          "expected at least 4D prepacked matmul");
96   assert(blocksStartDimPos.size() >= 2 &&
97          "expected starting outer and inner block positions");
98 
99   // Bias toward innermost dimensions.
100   unsigned outerBlockPos = operandMap.getNumResults() - 4;
101   unsigned innerBlockPos = operandMap.getNumResults() - 2;
102 
103   // Transpose control options define the desired block and element layout.
104   // Block transposition (outer dimensions) or element transposition (inner
105   // dimensions) may not be necessary depending on the original matmul data
106   // layout.
107   bool isOuterTransposed =
108       operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
109   bool isInnerTransposed =
110       operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
111 
112   // Transpose only the dimensions that need that to conform to the provided
113   // transpotion settings.
114   SmallVector<int64_t> innerPerm = {0, 1};
115   if (isInnerTransposed != transposeInnerBlocks)
116     innerPerm = {1, 0};
117   SmallVector<int64_t> outerPerm = {0, 1};
118   if (isOuterTransposed != transposeOuterBlocks)
119     outerPerm = {1, 0};
120 
121   // Leave the outer dimensions, like batch, unchanged by offsetting all
122   // outer dimensions permutations.
123   SmallVector<int64_t> offsetPerms;
124   for (auto i : llvm::seq(0u, outerBlockPos))
125     offsetPerms.push_back(i);
126   for (auto perm : outerPerm)
127     offsetPerms.push_back(perm + outerBlockPos);
128   outerPerm = offsetPerms;
129 
130   FailureOr<PackTransposeResult> packTransposedMatmul =
131       packTranspose(rewriter, packOp, linalgOp,
132                     /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
133 
134   return packTransposedMatmul;
135 }
136 
137 /// Pack a matmul operation into blocked 4D layout.
138 FailureOr<PackResult>
139 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
140                         const ControlBlockPackMatmulFn &controlPackMatmul) {
141   if (linalgOp.hasPureBufferSemantics())
142     return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
143 
144   std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
145   if (!options)
146     return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
147 
148   if (options->blockFactors.size() != 3)
149     return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
150 
151   SmallVector<OpFoldResult> mnkTiles =
152       getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
153 
154   // If padding is disabled, make sure that dimensions can be packed cleanly.
155   if (!options->allowPadding &&
156       !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
157     return rewriter.notifyMatchFailure(linalgOp,
158                                        "expect packing full tiles only");
159   }
160 
161   OpBuilder::InsertionGuard guard(rewriter);
162   // The op is replaced, we need to set the insertion point after it.
163   rewriter.setInsertionPointAfter(linalgOp);
164 
165   // Pack the matmul operation into blocked layout with two levels of
166   // subdivision:
167   //   - major 2D blocks - outer dimensions, consist of minor blocks
168   //   - minor 2D blocks - inner dimensions, consist of scalar elements
169   FailureOr<PackResult> packedMatmul = packMatmulGreedily(
170       rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
171       options->mnkOrder);
172   if (failed(packedMatmul))
173     return failure();
174 
175   assert(packedMatmul->packOps.size() == 3 &&
176          "invalid number of pack ops after matmul packing");
177   assert(packedMatmul->unPackOps.size() == 1 &&
178          "invalid number of unpack ops after matmul packing");
179 
180   FailureOr<ContractionDimensions> contractDims =
181       inferContractionDims(packedMatmul->packedLinalgOp);
182   if (failed(contractDims))
183     return failure();
184 
185   auto genericOp =
186       dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
187   SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
188 
189   // Transpose LHS matrix according to the options.
190   FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
191       rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
192       contractDims->m, options->lhsTransposeOuterBlocks,
193       options->lhsTransposeInnerBlocks);
194   if (failed(packedLhs))
195     return failure();
196 
197   // Update results.
198   packedMatmul->packOps[0] = packedLhs->transposedPackOp;
199   packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
200 
201   // Transpose RHS matrix according to the options.
202   FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
203       rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
204       contractDims->k, options->rhsTransposeOuterBlocks,
205       options->rhsTransposeInnerBlocks);
206   if (failed(packedRhs))
207     return failure();
208 
209   // Update results.
210   packedMatmul->packOps[1] = packedRhs->transposedPackOp;
211   packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
212 
213   return packedMatmul;
214 }
215 
216 namespace {
217 template <typename OpTy>
218 struct BlockPackMatmul : public OpRewritePattern<OpTy> {
219   BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
220                   PatternBenefit benefit = 1)
221       : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
222 
223   LogicalResult matchAndRewrite(OpTy linalgOp,
224                                 PatternRewriter &rewriter) const override {
225     FailureOr<PackResult> packedMatmul =
226         blockPackMatmul(rewriter, linalgOp, controlFn);
227     if (failed(packedMatmul))
228       return failure();
229     return success();
230   }
231 
232 private:
233   ControlBlockPackMatmulFn controlFn;
234 };
235 
236 template <>
237 struct BlockPackMatmul<linalg::GenericOp>
238     : public OpRewritePattern<linalg::GenericOp> {
239   BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
240                   PatternBenefit benefit = 1)
241       : OpRewritePattern<linalg::GenericOp>(context, benefit),
242         controlFn(std::move(fun)) {}
243 
244   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
245                                 PatternRewriter &rewriter) const override {
246     // Match suitable generics.
247     if (!linalg::isaContractionOpInterface(linalgOp)) {
248       return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
249     }
250 
251     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
252     auto infer = [&](MapList m) {
253       return AffineMap::inferFromExprList(m, linalgOp.getContext());
254     };
255 
256     AffineExpr i, j, k;
257     bindDims(linalgOp->getContext(), i, j, k);
258     SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
259 
260     // For now, only match simple matmuls.
261     if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
262           maps == infer({{k, i}, {k, j}, {i, j}}) ||
263           maps == infer({{i, k}, {j, k}, {i, j}}))) {
264       return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
265     }
266 
267     FailureOr<PackResult> packedMatmul =
268         blockPackMatmul(rewriter, linalgOp, controlFn);
269     if (failed(packedMatmul))
270       return failure();
271     return success();
272   }
273 
274 private:
275   ControlBlockPackMatmulFn controlFn;
276 };
277 
278 /// Convert linalg matmul ops to block layout and back.
279 struct LinalgBlockPackMatmul
280     : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
281   using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
282 
283   void runOnOperation() override {
284     Operation *op = getOperation();
285     RewritePatternSet patterns(&getContext());
286 
287     ControlBlockPackMatmulFn controlFn =
288         [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
289       BlockPackMatmulOptions options;
290       options.blockFactors = SmallVector<int64_t>{*blockFactors};
291       options.allowPadding = allowPadding;
292       options.mnkPaddedSizesNextMultipleOf =
293           SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
294       if (!mnkOrder.empty())
295         options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
296       options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
297       options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
298       options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
299       options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
300       return options;
301     };
302 
303     linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
304     if (failed(applyPatternsGreedily(op, std::move(patterns))))
305       return signalPassFailure();
306   }
307 };
308 } // namespace
309 
/// Populate `patterns` with block-packing rewrites for all supported matmul
/// flavors. The linalg.generic specialization additionally checks that the
/// generic implements a plain (possibly transposed) matmul contraction
/// before packing; the named-op patterns pack unconditionally.
void linalg::populateBlockPackMatmulPatterns(
    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
  patterns.add<BlockPackMatmul<linalg::GenericOp>,
               BlockPackMatmul<linalg::MatmulOp>,
               BlockPackMatmul<linalg::BatchMatmulOp>,
               BlockPackMatmul<linalg::MatmulTransposeAOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
               BlockPackMatmul<linalg::MatmulTransposeBOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
      patterns.getContext(), controlFn);
}
321