//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses its operands swapped, e.g. this differentiates between
// a linalg.generic body that contains:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %a, %b : f32
//     linalg.yield %0 : f32
// and one that contains:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %b, %a : f32
//     linalg.yield %0 : f32
// The former is linalg.sub(a, b), the latter is linalg.sub(b, a).
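// For illustration only (the tensor types below are assumed, not taken from
// any particular input), the swapped body would specialize to the named op
// with its inputs reversed:
//   linalg.sub ins(%b, %a : tensor<?xf32>, tensor<?xf32>)
//              outs(%c : tensor<?xf32>) -> tensor<?xf32>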
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies a linalg.generic that is essentially a named op of the form:
/// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//
// It is possible that a linalg.generic implements a matmul, but not in a
// straightforward way; e.g. below is a matrix multiply over some slice:
// ```
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//          iterator_types = ["parallel", "parallel", "parallel"]}
//          ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//          outs(%C : tensor<20x20x20xf32>) {
//        ^bb0(%a: f32, %b: f32, %c : f32):
//          %mul = arith.mulf %a, %b : f32
//          %add = arith.addf %mul, %c : f32
//          linalg.yield %add : f32
//   } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op;
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check whether
// the map of A is the identity (match), transposed, or something completely
// different (mismatch). Similarly for B and C.
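//
// A concrete sketch (the loop dims and expected positions below are assumed
// for illustration): with loop dims (d0, d1, d2) = (m, n, k), rowDimIdx = 0,
// expectedPosOfRowDim = 0 (m) and expectedPosOfColDim = 2 (k):
//   affine_map<(d0, d1, d2) -> (d0, d2)>  yields Match
//   affine_map<(d0, d1, d2) -> (d2, d0)>  yields Transposed
//   affine_map<(d0, d1, d2) -> (d1, d2)>  yields Mismatch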
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}

// Replaces genericOp with the `NamedOpTy` op, supplied as a template arg.
// All the variants expressed by the pseudo regular expression
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp out the
// different versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to a named linalg.*matmul* op where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if the indexing maps are not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();
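
  // For example (an illustrative map, not taken from the IR above), a result
  // expression such as the one in affine_map<(d0, d1, d2) -> (d0 + d1, d2)>
  // is not a projected permutation, so such a generic is rejected here.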

  // A linalg.generic contraction can be across multiple axes, e.g.
  // ```
  //   linalg.generic
  //     {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                       affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                       affine_map<(m, n, k1, k2) -> (m, n)>],
  //      iterator_types = ["parallel", "parallel",
  //                        "reduction", "reduction"]}
  //     ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //     outs(%C : tensor<10x40xf32>) {
  //   ^bb0(%a: f32, %b: f32, %c: f32):
  //     %1 = arith.mulf %a, %b : f32
  //     %2 = arith.addf %c, %1 : f32
  //     linalg.yield %2 : f32
  //   } -> tensor<10x40xf32>
  // ```
  // In the above contraction there are two reduction dimensions {k1, k2};
  // although it is a valid linalg contraction, it is not a named-op
  // matrix-multiply kind. Therefore, reject multi-dimensional reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation of its batch dimensions, but for a named op the batch maps
    // must be the identity, since separate maps are not specified.
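    //
    // For example (an illustrative map, not taken from any input), with one
    // batch dimension and loop dims (b, m, n, k), an operand map such as
    //   affine_map<(b, m, n, k) -> (m, b, k)>
    // places the batch dim in a non-leading position and is rejected below.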
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Codegen the different matmul variants.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
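// For illustration (a sketch; the tensor types and the #id identity map,
// affine_map<(d0, d1) -> (d0, d1)>, are assumed), a generic such as:
//   %0 = linalg.generic
//          {indexing_maps = [#id, #id],
//           iterator_types = ["parallel", "parallel"]}
//          ins(%in : tensor<8x16xf32>) outs(%out : tensor<8x16xf32>) {
//        ^bb0(%a: f32, %b: f32):
//          linalg.yield %a : f32
//   } -> tensor<8x16xf32>
// specializes to:
//   %0 = linalg.copy ins(%in : tensor<8x16xf32>)
//                    outs(%out : tensor<8x16xf32>) -> tensor<8x16xf32>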
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}
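
// Example usage (a sketch, not normative documentation): these specialization
// patterns can be exercised through the standalone pass defined below,
// typically invoked as
//   mlir-opt --linalg-specialize-generic-ops input.mlir
// (the flag name is assumed to match the tablegen'd pass definition), or
// programmatically, assuming `ctx` and `funcOp` are in scope:
//   RewritePatternSet patterns(ctx);
//   populateLinalgGenericOpsSpecializationPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));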

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}