xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
16cbcb793Slorenzo chelini //===- Specialize.cpp - linalg generic ops to named ops  ------------------===//
26cbcb793Slorenzo chelini //
36cbcb793Slorenzo chelini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46cbcb793Slorenzo chelini // See https://llvm.org/LICENSE.txt for license information.
56cbcb793Slorenzo chelini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66cbcb793Slorenzo chelini //
76cbcb793Slorenzo chelini //===----------------------------------------------------------------------===//
86cbcb793Slorenzo chelini //
96cbcb793Slorenzo chelini // This file implements a method to specialize generic operations to named
106cbcb793Slorenzo chelini // operations. Conceptually it is the opposite of generalize.cpp.
116cbcb793Slorenzo chelini //
126cbcb793Slorenzo chelini //===----------------------------------------------------------------------===//
136cbcb793Slorenzo chelini 
143efac5c6SJaved Absar #include "mlir/Dialect/Complex/IR/Complex.h"
156cbcb793Slorenzo chelini #include "mlir/Dialect/Linalg/IR/Linalg.h"
166cbcb793Slorenzo chelini #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
173efac5c6SJaved Absar #include "mlir/Dialect/Linalg/Passes.h"
186cbcb793Slorenzo chelini #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1933b78338SJaved Absar #include "mlir/Dialect/Math/IR/Math.h"
203efac5c6SJaved Absar #include "mlir/IR/PatternMatch.h"
213efac5c6SJaved Absar #include "mlir/Support/TypeID.h"
223efac5c6SJaved Absar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
236cbcb793Slorenzo chelini #include "llvm/Support/Debug.h"
246cbcb793Slorenzo chelini 
253efac5c6SJaved Absar namespace mlir {
263efac5c6SJaved Absar #define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
273efac5c6SJaved Absar #include "mlir/Dialect/Linalg/Passes.h.inc"
283efac5c6SJaved Absar } // namespace mlir
293efac5c6SJaved Absar 
306cbcb793Slorenzo chelini #define DEBUG_TYPE "linalg-specialization"
316cbcb793Slorenzo chelini 
// Replaces `genericOp` with the named binary op NEWOP, forwarding the two dps
// inputs as operands and the single dps init as the output. Expects
// `rewriter` (a RewriterBase) and `genericOp` (a GenericOp) to be in scope at
// the expansion site. When OPERANDS_SWAP is true the two inputs are emitted
// in reversed order (see areBinOpsSwapped below for why this matters).
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

// Replaces `genericOp` with the named unary op NEWOP taking the single dps
// input and the single dps init. Same expansion-site expectations as
// REPLACE_BINARY_OP.
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))
4333b78338SJaved Absar 
446cbcb793Slorenzo chelini using namespace mlir;
456cbcb793Slorenzo chelini using namespace mlir::linalg;
466cbcb793Slorenzo chelini 
4733b78338SJaved Absar // Given a elementwise single binary linalg generic op, checks whether the
4833b78338SJaved Absar // binary op accesses operands as swapped. e.g.
4933b78338SJaved Absar // this differentiates between a linalg-generic body that contains:
5033b78338SJaved Absar //    ^bb0(%a: f32, %b: f32, %c : f32):
5133b78338SJaved Absar //         %0 = arith.subf %a, %b : f32
5233b78338SJaved Absar //         linalg.yield %0: f32
5333b78338SJaved Absar // against:
5433b78338SJaved Absar //    ^bb0(%a: f32, %b: f32, %c : f32):
5533b78338SJaved Absar //         %0 = arith.subf %b, %a : f32
5633b78338SJaved Absar //         linalg.yield %0: f32
5733b78338SJaved Absar // Former is linalg.sub(a,b), latter is linalg.sub(b,a).
5833b78338SJaved Absar static bool areBinOpsSwapped(GenericOp genericOp) {
5933b78338SJaved Absar   Block *body = genericOp.getBody();
6033b78338SJaved Absar   Operation *op = &body->front();
6133b78338SJaved Absar   bool swapped = false;
6233b78338SJaved Absar   if (op->getOpOperand(0).get() != body->getArgument(0)) {
6333b78338SJaved Absar     swapped = true;
6433b78338SJaved Absar     assert(op->getOpOperand(0).get() == body->getArgument(1) &&
6533b78338SJaved Absar            op->getOpOperand(1).get() == body->getArgument(0) &&
6633b78338SJaved Absar            "binary op uses just one block arg");
6733b78338SJaved Absar   }
6833b78338SJaved Absar   return swapped;
6933b78338SJaved Absar }
7033b78338SJaved Absar 
713efac5c6SJaved Absar //===----------------------------------------------------------------------===//
723efac5c6SJaved Absar // Specialize linalg generic to matmul variants.
733efac5c6SJaved Absar //===----------------------------------------------------------------------===//
743efac5c6SJaved Absar /// Identifies linalg.generic that is essentially named op of the form:
753efac5c6SJaved Absar //    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
763efac5c6SJaved Absar //
773efac5c6SJaved Absar // It is possible that a linalg.generic may be implementing a matmul but not
783efac5c6SJaved Absar // in a straight-forward way e.g. below is matrix multiply over some slice
793efac5c6SJaved Absar // ```
803efac5c6SJaved Absar //  %0 = linalg.generic {
813efac5c6SJaved Absar //          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
823efac5c6SJaved Absar //                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
833efac5c6SJaved Absar //                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
843efac5c6SJaved Absar //          iterator_types = ["parallel", "parallel", "parallel"]}
853efac5c6SJaved Absar //          ins(%A, %B : tensor<20x20x20xf32>,  tensor<20x20x20xf32>)
863efac5c6SJaved Absar //          outs(%C : tensor<20x20x20xf32>) {
873efac5c6SJaved Absar //             ^bb0(%a: f32, %b: f32, %c : f32):
883efac5c6SJaved Absar //                %mul = arith.mulf %a, %b : f32
893efac5c6SJaved Absar //                %add = arith.addf %mul, %c : f32
903efac5c6SJaved Absar //                linalg.yield %add : f32
913efac5c6SJaved Absar //       } -> tensor<20x20x20xf32>
923efac5c6SJaved Absar // ```
933efac5c6SJaved Absar // It is not possible to represent above as named op.
943efac5c6SJaved Absar // e.g. linalg.batch_matmul(%A, %B :  tensor<20x20x20xf32>, ...) is
953efac5c6SJaved Absar // not  the same as linalg.generic above.
963efac5c6SJaved Absar namespace {
// Result of classifying one operand's indexing map against the canonical
// (row, col) dimension positions of a matmul operand (see matchOperandMap).
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};
1023efac5c6SJaved Absar 
1033efac5c6SJaved Absar // Checks whether the input Affine `map` contains two consecutive dims that
1043efac5c6SJaved Absar // can be interpreted as accessing a 2D matrix. It is assumed that the row
1053efac5c6SJaved Absar // column dimension are adjacent axis (in this order) and start at
1063efac5c6SJaved Absar // `rowDimIdx` in the input map.
1073efac5c6SJaved Absar //
1083efac5c6SJaved Absar //  e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
1093efac5c6SJaved Absar //  whether the map of A is identity (match), transposed, or something
1103efac5c6SJaved Absar //  completely different (mis-match). Similar for B and C.
1113efac5c6SJaved Absar static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
1123efac5c6SJaved Absar                                         unsigned expectedPosOfRowDim,
1133efac5c6SJaved Absar                                         unsigned expectedPosOfColDim) {
1143efac5c6SJaved Absar   // Get the matrix multiply indices. They are past the batch indices.
1153efac5c6SJaved Absar   auto exprOfRowDim = map.getResults()[rowDimIdx];
1163efac5c6SJaved Absar   auto exprOfColDim = map.getResults()[rowDimIdx + 1];
1173efac5c6SJaved Absar 
1183efac5c6SJaved Absar   // They should be pure dimension ids.
1193efac5c6SJaved Absar   if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
1203efac5c6SJaved Absar       exprOfColDim.getKind() != AffineExprKind::DimId)
1213efac5c6SJaved Absar     return IndexMatchResult::Mismatch;
1223efac5c6SJaved Absar 
1233efac5c6SJaved Absar   auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
1243efac5c6SJaved Absar   auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
1253efac5c6SJaved Absar 
1263efac5c6SJaved Absar   if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
1273efac5c6SJaved Absar     return IndexMatchResult::Match;
1283efac5c6SJaved Absar 
1293efac5c6SJaved Absar   if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
1303efac5c6SJaved Absar     return IndexMatchResult::Transposed;
1313efac5c6SJaved Absar 
1323efac5c6SJaved Absar   return IndexMatchResult::Mismatch;
1333efac5c6SJaved Absar }
1343efac5c6SJaved Absar 
1353efac5c6SJaved Absar // Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
1363efac5c6SJaved Absar //  All the variants expressed as pseudo regular expression:
1373efac5c6SJaved Absar //      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
1383efac5c6SJaved Absar //  have same number of ins/out, so its easy to stamp different versions.
1393efac5c6SJaved Absar template <typename NamedOpTy>
1403efac5c6SJaved Absar static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
1413efac5c6SJaved Absar   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
1423efac5c6SJaved Absar       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
1433efac5c6SJaved Absar       ValueRange{op.getDpsInits()[0]});
1443efac5c6SJaved Absar   return namedOp;
1453efac5c6SJaved Absar }
1463efac5c6SJaved Absar 
// Converts linalg.generic to named linalg.*matmul* where possible.
// Returns the replacement named LinalgOp on success, failure() when the
// generic does not correspond to any `{batch_}?matmul{_transpose_a|_b}?`.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  // All matmul variants have exactly two inputs and one init.
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // Linalg generic contraction can be across multiple axis e.g.
  // ```
  //      linalg.generic
  //           {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                             affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                             affine_map<(m, n, k1, k2) -> (m, n)>],
  //           iterator_types = ["parallel", "parallel",
  //                             "reduction", "reduction"]}
  //           ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //           outs(%C : tensor<10x40xf32>) {
  //           ^bb0(%a: f32, %b: f32, %c: f32):
  //                 %1 = arith.mulf %a, %b : f32
  //                 %2 = arith.addf %c, %1 : f32
  //                 linalg.yield %2 : f32
  //      } -> tensor<10x40xf32>
  //  ```
  //  In above contraction, there are two reduction dimensions {k1, k2}
  //  and although a valid linalg contraction, it is not a named-op
  //  matrix multiply kind. Therefore, reject multi-dim reduction.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  // Named matmuls have exactly one m, one n, and one k dimension.
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  // The payload must be exactly a multiply-accumulate, in one of the
  // float, integer, or complex flavors.
  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check rank of operands
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  // The iteration space must be exactly {batch dims} + {m, n, k}.
  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction  could express different
    // permutations for its batch dimension. But for named op it must be
    // identity since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  // Classify each operand's access as identity/transposed relative to the
  // canonical C[m,n] = A[m,k] * B[k,n] layout.
  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  // The output map must be identity, and there is no named op with both A
  // and B transposed.
  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  /// Codegen the different matmul variants.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
2543efac5c6SJaved Absar 
2553efac5c6SJaved Absar } // namespace
2563efac5c6SJaved Absar 
2573efac5c6SJaved Absar //===----------------------------------------------------------------------===//
2583efac5c6SJaved Absar // Categorize linalg generic to named op where possible.
2593efac5c6SJaved Absar //===----------------------------------------------------------------------===//
2606cbcb793Slorenzo chelini FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
2616cbcb793Slorenzo chelini                                                       GenericOp genericOp) {
262c13f806fSJaved Absar   // Copy
2636cbcb793Slorenzo chelini   if (isaCopyOpInterface(genericOp)) {
2646cbcb793Slorenzo chelini     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
2656cbcb793Slorenzo chelini         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
2666cbcb793Slorenzo chelini     return namedOp;
2676cbcb793Slorenzo chelini   }
26833b78338SJaved Absar 
269c13f806fSJaved Absar   // Fill
27033b78338SJaved Absar   if (isaFillOpInterface(genericOp)) {
27133b78338SJaved Absar     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
27233b78338SJaved Absar         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
27333b78338SJaved Absar     return namedOp;
27433b78338SJaved Absar   }
27533b78338SJaved Absar 
276c13f806fSJaved Absar   // Broadcast
277c13f806fSJaved Absar   std::optional<SmallVector<int64_t>> equivalentToBroadcast =
278c13f806fSJaved Absar       isaBroadcastOpInterface(genericOp);
279c13f806fSJaved Absar   if (equivalentToBroadcast) {
280c13f806fSJaved Absar     auto dims = *equivalentToBroadcast;
281c13f806fSJaved Absar     LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
282c13f806fSJaved Absar         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
283c13f806fSJaved Absar         dims);
284c13f806fSJaved Absar     return namedOp;
285c13f806fSJaved Absar   }
286c13f806fSJaved Absar 
287c13f806fSJaved Absar   // Transpose
288c13f806fSJaved Absar   std::optional<SmallVector<int64_t>> equivalentToTranspose =
289c13f806fSJaved Absar       isaTransposeOpInterface(genericOp);
290c13f806fSJaved Absar   if (equivalentToTranspose) {
291c13f806fSJaved Absar     auto permutation = *equivalentToTranspose;
292c13f806fSJaved Absar     LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
293c13f806fSJaved Absar         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
294c13f806fSJaved Absar         permutation);
295c13f806fSJaved Absar     return namedOp;
296c13f806fSJaved Absar   }
297c13f806fSJaved Absar 
298c13f806fSJaved Absar   // Elementwise Unary
29933b78338SJaved Absar   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
30033b78338SJaved Absar     Operation *op = &genericOp.getBody()->front();
30133b78338SJaved Absar     if (isa<math::ExpOp>(op)) {
30233b78338SJaved Absar       LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
30333b78338SJaved Absar       return namedOp;
30433b78338SJaved Absar     }
30533b78338SJaved Absar   }
30633b78338SJaved Absar 
307c13f806fSJaved Absar   // Elementwise Binary
30833b78338SJaved Absar   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
30933b78338SJaved Absar     bool swap = areBinOpsSwapped(genericOp);
31033b78338SJaved Absar     Operation *op = &genericOp.getBody()->front();
31133b78338SJaved Absar     if (isa<arith::AddFOp>(op)) {
31233b78338SJaved Absar       LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
31333b78338SJaved Absar       return namedOp;
31433b78338SJaved Absar     }
31533b78338SJaved Absar     if (isa<arith::SubFOp>(op)) {
31633b78338SJaved Absar       LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
31733b78338SJaved Absar       return namedOp;
31833b78338SJaved Absar     }
31933b78338SJaved Absar     if (isa<arith::MulFOp>(op)) {
32033b78338SJaved Absar       LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
32133b78338SJaved Absar       return namedOp;
32233b78338SJaved Absar     }
32333b78338SJaved Absar     if (isa<arith::DivFOp>(op)) {
32433b78338SJaved Absar       LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
32533b78338SJaved Absar       return namedOp;
32633b78338SJaved Absar     }
32733b78338SJaved Absar   }
3283efac5c6SJaved Absar 
329c13f806fSJaved Absar   // Contraction - e.g. matmul
3303efac5c6SJaved Absar   if (isaContractionOpInterface(genericOp)) {
3313efac5c6SJaved Absar     return specializeLinalgContractions(rewriter, genericOp);
3323efac5c6SJaved Absar   }
3336cbcb793Slorenzo chelini   return failure();
3346cbcb793Slorenzo chelini }
3353efac5c6SJaved Absar 
namespace {
// Pass that rewrites linalg.generic ops into equivalent named linalg ops
// where a specialization exists. The base class is generated from tablegen
// (GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS in Passes.h.inc above).
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  // Inherit constructors (and pass options) from the generated base.
  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace
3463efac5c6SJaved Absar 
3473efac5c6SJaved Absar void LinalgSpecializeGenericOpsPass::runOnOperation() {
3483efac5c6SJaved Absar   RewritePatternSet patterns(&getContext());
3493efac5c6SJaved Absar   populateLinalgGenericOpsSpecializationPatterns(patterns);
3500ac4821bSJaved Absar   populateDecomposeProjectedPermutationPatterns(patterns);
3513efac5c6SJaved Absar 
352*09dfc571SJacques Pienaar   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
3533efac5c6SJaved Absar     signalPassFailure();
3543efac5c6SJaved Absar }
3553efac5c6SJaved Absar 
// Registers the generic-to-named-op specialization pattern into `patterns`.
void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}
360