//===- Specialize.cpp - linalg generic ops to named ops  ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

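// Helper macros that replace a single-operation elementwise linalg.generic
// with the named op `NEWOP`, forwarding the generic's DPS inputs and init.
// `OPERANDS_SWAP` swaps the two inputs when the generic's body used its block
// arguments in reverse order.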
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses its operands in swapped order, e.g. this differentiates
// between a linalg.generic body that contains:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %a, %b : f32
//         linalg.yield %0: f32
// and one that contains:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %b, %a : f32
//         linalg.yield %0: f32
// The former is linalg.sub(a,b), the latter is linalg.sub(b,a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies a linalg.generic that is essentially a named op of the form:
//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
// It is possible that a linalg.generic implements a matmul but not in a
// straightforward way, e.g. below is a matrix multiply over some slice:
// ```
//  %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//          iterator_types = ["parallel", "parallel", "parallel"]}
//          ins(%A, %B : tensor<20x20x20xf32>,  tensor<20x20x20xf32>)
//          outs(%C : tensor<20x20x20xf32>) {
//             ^bb0(%a: f32, %b: f32, %c : f32):
//                %mul = arith.mulf %a, %b : f32
//                %add = arith.addf %mul, %c : f32
//                linalg.yield %add : f32
//       } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op, e.g.
// linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is not the same as
// the linalg.generic above.
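//
// By contrast, a generic whose maps are plain permutations of the iteration
// dims does map to a named op. A minimal sketch for illustration:
// ```
//  indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                   affine_map<(m, n, k) -> (k, n)>,
//                   affine_map<(m, n, k) -> (m, n)>]
//  iterator_types = ["parallel", "parallel", "reduction"]  // mul/add body
// ```
// is exactly linalg.matmul and is rewritten as such below.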
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row and
// column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
//  e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check whether
//  the map of A is the identity (match), transposed, or something completely
//  different (mismatch). Similarly for B and C.
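//
//  For instance (an illustrative sketch, not exhaustive): for a plain matmul
//  where m = d0, n = d1, k = d2, the A operand is checked with
//  expectedPosOfRowDim = 0 (m) and expectedPosOfColDim = 2 (k), so:
//    affine_map<(d0, d1, d2) -> (d0, d2)>  ==>  Match
//    affine_map<(d0, d1, d2) -> (d2, d0)>  ==>  Transposed
//    affine_map<(d0, d1, d2) -> (d1, d2)>  ==>  Mismatch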
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}

// Replaces genericOp with a `NamedOpTy` op, supplied as a template arg.
//  All the variants expressed by the pseudo regular expression:
//      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//  have the same number of ins/outs, so it is easy to stamp out the
//  different versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if the indexing maps are not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg.generic contraction can be across multiple axes, e.g.
  // ```
  //      linalg.generic
  //           {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                             affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                             affine_map<(m, n, k1, k2) -> (m, n)>],
  //           iterator_types = ["parallel", "parallel",
  //                             "reduction", "reduction"]}
  //           ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //           outs(%C : tensor<10x40xf32>) {
  //           ^bb0(%a: f32, %b: f32, %c: f32):
  //                 %1 = arith.mulf %a, %b : f32
  //                 %2 = arith.addf %c, %1 : f32
  //                 linalg.yield %2 : f32
  //      } -> tensor<10x40xf32>
  //  ```
  //  In the above contraction there are two reduction dimensions {k1, k2};
  //  although it is a valid linalg contraction, it is not a named-op
  //  matrix-multiply kind. Therefore, reject multi-dim reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

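  // The body must be the canonical multiply-then-accumulate pattern, over
  // float, integer, or complex element types.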
  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation for its batch dimensions. But for a named op they must be
    // the identity, since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

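  // Match each operand's two matrix dims (the results following any batch
  // dims) against the canonical matmul accesses A[m, k], B[k, n], C[m, n].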
  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Codegen the different matmul variants.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize a linalg.generic to a named op where possible.
//===----------------------------------------------------------------------===//
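//
// Currently handled: linalg.copy, linalg.fill, linalg.broadcast,
// linalg.transpose, a single unary elementwise op (exp), a single binary
// elementwise op (addf/subf/mulf/divf), and the matmul variants handled by
// specializeLinalgContractions above.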
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
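    // If the generic's body consumed its block arguments in swapped order,
    // forward the DPS inputs to the named op swapped as well (see
    // REPLACE_BINARY_OP) so non-commutative ops such as sub/div keep their
    // semantics.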
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

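// Greedily applies the specialization pattern (plus decomposition of projected
// permutations) across the whole operation. From the command line this is
// typically exercised through the registered pass flag, e.g. (assuming the
// flag as spelled in Passes.td):
//   mlir-opt --linalg-specialize-generic-ops input.mlir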
void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}