//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))
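
// For illustration, assuming a binary generic whose operands are swapped,
// REPLACE_BINARY_OP(SubOp, /*OPERANDS_SWAP=*/true) expands roughly to:
//   rewriter.replaceOpWithNewOp<SubOp>(
//       genericOp,
//       ValueRange{genericOp.getDpsInputs()[1], genericOp.getDpsInputs()[0]},
//       ValueRange{genericOp.getDpsInits()[0]});
// i.e. the two dps inputs are exchanged before building the named op.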
using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses its operands as swapped, e.g.
// this differentiates between a linalg-generic body that contains:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %a, %b : f32
//         linalg.yield %0 : f32
// against:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %b, %a : f32
//         linalg.yield %0 : f32
// The former is linalg.sub(a, b), the latter is linalg.sub(b, a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//

/// Identifies a linalg.generic that is essentially a named op of the form:
//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
// It is possible that a linalg.generic implements a matmul but not in a
// straightforward way, e.g. the example below is a matrix multiply over a
// slice of the operands:
// ```
//  %0 = linalg.generic {
//         indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                          affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                          affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//         iterator_types = ["parallel", "parallel", "parallel"]}
//       ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//       outs(%C : tensor<20x20x20xf32>) {
//       ^bb0(%a: f32, %b: f32, %c : f32):
//          %mul = arith.mulf %a, %b : f32
//          %add = arith.addf %mul, %c : f32
//          linalg.yield %add : f32
//       } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op, e.g.
// linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is not the same as
// the linalg.generic above.
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row and
// column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check whether
// the map of A is identity (Match), Transposed, or something completely
// different (Mismatch). Similarly for B and C.
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
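
// For illustration, assuming the iteration dims are ordered (b, m, n, k) with
// positions (0, 1, 2, 3), calling matchOperandMap on the map of A with
// rowDimIdx = 1, expectedPosOfRowDim = m = 1 and expectedPosOfColDim = k = 3
// yields:
//   affine_map<(b, m, n, k) -> (b, m, k)>  ->  Match
//   affine_map<(b, m, n, k) -> (b, k, m)>  ->  Transposed
//   affine_map<(b, m, n, k) -> (b, n, k)>  ->  Mismatch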
// Replaces genericOp with a `NamedOpTy` op, supplied as a template arg.
// All the variants expressed by the pseudo regular expression:
//    `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp out the different
// versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter,
                                         GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg generic contraction can be across multiple axes, e.g.
  // ```
  //  linalg.generic
  //      {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                        affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                        affine_map<(m, n, k1, k2) -> (m, n)>],
  //       iterator_types = ["parallel", "parallel",
  //                         "reduction", "reduction"]}
  //      ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //      outs(%C : tensor<10x40xf32>) {
  //      ^bb0(%a: f32, %b: f32, %c: f32):
  //        %1 = arith.mulf %a, %b : f32
  //        %2 = arith.addf %c, %1 : f32
  //        linalg.yield %2 : f32
  //      } -> tensor<10x40xf32>
  // ```
  // In the above contraction there are two reduction dimensions {k1, k2};
  // although it is a valid linalg contraction, it is not a named-op
  // matrix-multiply kind. Therefore, reject multi-dim reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation of its batch dimensions. But for a named op the batch dims
    // must be the identity, since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Codegen the different matmul variants.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
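
// For illustration, assuming identity indexing maps and arbitrary tensor
// shapes, a generic whose body is a single math.exp such as
// ```
//  %0 = linalg.generic
//      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0, d1)>],
//       iterator_types = ["parallel", "parallel"]}
//      ins(%A : tensor<4x8xf32>) outs(%B : tensor<4x8xf32>) {
//      ^bb0(%a: f32, %b: f32):
//        %e = math.exp %a : f32
//        linalg.yield %e : f32
//      } -> tensor<4x8xf32>
// ```
// is rewritten by specializeGenericOp below to the named op
// `linalg.exp ins(%A) outs(%B)`.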
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}

namespace {
// Rewrite pattern wrapping specializeGenericOp so that it can be driven by
// the greedy pattern rewriter below.
struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<LinalgOp> namedOp = specializeGenericOp(rewriter, op);
    if (failed(namedOp))
      return failure();
    return success();
  }
};

struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {
  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}
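
// Example usage (a minimal sketch; `funcOp` stands for whatever op the
// rewrite should be anchored on): the specialization patterns can also be
// applied directly, without going through the pass:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateLinalgGenericOpsSpecializationPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));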