xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp (revision 971b852546a7d96bc8887ced913724b884cf40df)
//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
853a0d45dSSean Silva 
953a0d45dSSean Silva #include "mlir/Dialect/Linalg/Passes.h"
1053a0d45dSSean Silva 
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
12b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
13ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14f75f391fSRob Suderman #include "mlir/Dialect/Linalg/Utils/Utils.h"
1553a0d45dSSean Silva #include "mlir/Transforms/DialectConversion.h"
1653a0d45dSSean Silva 
1767d0d7acSMichele Scuttari namespace mlir {
181e98d488SQuinn Dawkins #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
1967d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc"
2067d0d7acSMichele Scuttari } // namespace mlir
2167d0d7acSMichele Scuttari 
2253a0d45dSSean Silva using namespace mlir;
2353a0d45dSSean Silva 
isElementwiseMappableOpOnRankedTensors(Operation * op)2453a0d45dSSean Silva static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
25bcc9b371SFrederik Gossen   if (!OpTrait::hasElementwiseMappableTraits(op))
2653a0d45dSSean Silva     return false;
2753a0d45dSSean Silva 
2853a0d45dSSean Silva   // TODO: The conversion pattern can be made to work for `any_of` here, but
2953a0d45dSSean Silva   // it's more complex as it requires tracking which operands are scalars.
30*971b8525SJakub Kuderski   return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
3153a0d45dSSean Silva }
3253a0d45dSSean Silva 
33b7ae1d3dSnicolasvasilache /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
34b7ae1d3dSnicolasvasilache /// the result types and return a list of values such that, for each result type
35b7ae1d3dSnicolasvasilache /// `t` and value `v` at the same index `idx`:
36b7ae1d3dSnicolasvasilache ///   1. `v.getType() == t`
37b7ae1d3dSnicolasvasilache ///   2. If an operand of `op` has type `t`, let `operand_first` be the first
38b7ae1d3dSnicolasvasilache ///      such operand. Then`v == operand_first`.
3981ca5aa4SMatthias Springer ///   3. Otherwise, v is a newly created `tensor::EmptyOp` with:
40b7ae1d3dSnicolasvasilache ///        a. Static and dynamic dims extracted from the first operand of `op`.
41b7ae1d3dSnicolasvasilache ///        b. Elemental type equal to the elemental type of `t`.
42b7ae1d3dSnicolasvasilache ///
43b7ae1d3dSnicolasvasilache /// This is sufficient because ElementwiseMappable guarantees that "The static
44b7ae1d3dSnicolasvasilache /// types of all vector (resp. tensor) operands and results must have the same
45b7ae1d3dSnicolasvasilache /// shape".
46b7ae1d3dSnicolasvasilache static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder & b,Operation * op)47b7ae1d3dSnicolasvasilache getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
48b7ae1d3dSnicolasvasilache   assert(isElementwiseMappableOpOnRankedTensors(op));
49b7ae1d3dSnicolasvasilache   Location loc = op->getLoc();
50b7ae1d3dSnicolasvasilache   ValueRange operands = op->getOperands();
51b7ae1d3dSnicolasvasilache   TypeRange rankedTensorTypes = op->getResultTypes();
52b7ae1d3dSnicolasvasilache   SmallVector<Value, 4> res;
53b7ae1d3dSnicolasvasilache   res.reserve(rankedTensorTypes.size());
54b7ae1d3dSnicolasvasilache   for (Type t : rankedTensorTypes) {
55b7ae1d3dSnicolasvasilache     // Try to find an operand with type matching the result tensor.
56b7ae1d3dSnicolasvasilache     bool found = false;
57b7ae1d3dSnicolasvasilache     for (Value v : operands) {
58b7ae1d3dSnicolasvasilache       if (v.getType() == t) {
59b7ae1d3dSnicolasvasilache         found = true;
60b7ae1d3dSnicolasvasilache         res.push_back(v);
61b7ae1d3dSnicolasvasilache         break;
62b7ae1d3dSnicolasvasilache       }
63b7ae1d3dSnicolasvasilache     }
64b7ae1d3dSnicolasvasilache     if (found)
65b7ae1d3dSnicolasvasilache       continue;
66b7ae1d3dSnicolasvasilache 
67b7ae1d3dSnicolasvasilache     // Extract static / dynamic shape mix from the first operand.
6881ca5aa4SMatthias Springer     res.push_back(b.create<tensor::EmptyOp>(
69be6d96e9SMatthias Springer         loc, tensor::getMixedSizes(b, loc, operands.front()),
70be6d96e9SMatthias Springer         cast<RankedTensorType>(t).getElementType()));
71b7ae1d3dSnicolasvasilache   }
72b7ae1d3dSnicolasvasilache   return res;
73b7ae1d3dSnicolasvasilache }
74b7ae1d3dSnicolasvasilache 
7553a0d45dSSean Silva namespace {
765488a6b0SSean Silva struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
ConvertAnyElementwiseMappableOpOnRankedTensors__anonab5fbb810111::ConvertAnyElementwiseMappableOpOnRankedTensors7776f3c2f3SRiver Riddle   ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
7876f3c2f3SRiver Riddle       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
matchAndRewrite__anonab5fbb810111::ConvertAnyElementwiseMappableOpOnRankedTensors7953a0d45dSSean Silva   LogicalResult matchAndRewrite(Operation *op,
8053a0d45dSSean Silva                                 PatternRewriter &rewriter) const final {
8153a0d45dSSean Silva     if (!isElementwiseMappableOpOnRankedTensors(op))
8253a0d45dSSean Silva       return rewriter.notifyMatchFailure(
8353a0d45dSSean Silva           op, "requires elementwise op on ranked tensors");
8453a0d45dSSean Silva 
855550c821STres Popp     auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
8653a0d45dSSean Silva     SmallVector<AffineMap, 3> indexingMaps(
8753a0d45dSSean Silva         op->getNumResults() + op->getNumOperands(),
8853a0d45dSSean Silva         rewriter.getMultiDimIdentityMap(rank));
89e6598b05SOleg Shyshkov     SmallVector<utils::IteratorType, 6> iteratorTypes(
90e6598b05SOleg Shyshkov         rank, utils::IteratorType::parallel);
91b7ae1d3dSnicolasvasilache     auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
9253a0d45dSSean Silva     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
9353a0d45dSSean Silva         op, /*resultTensorTypes=*/op->getResultTypes(),
9453a0d45dSSean Silva         /*inputs=*/op->getOperands(),
95b7ae1d3dSnicolasvasilache         /*outputs=*/outputs,
9653a0d45dSSean Silva         /*indexingMaps=*/indexingMaps,
9753a0d45dSSean Silva         /*iteratorTypes=*/iteratorTypes,
9853a0d45dSSean Silva         /*bodyBuilder=*/
9953a0d45dSSean Silva         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
10053a0d45dSSean Silva           auto resultTypes = llvm::to_vector<6>(
10153a0d45dSSean Silva               llvm::map_range(op->getResultTypes(), [](Type type) {
1025550c821STres Popp                 return cast<TensorType>(type).getElementType();
10353a0d45dSSean Silva               }));
10414ecafd0SChia-hung Duan           auto *scalarOp =
10514ecafd0SChia-hung Duan               builder.create(loc, op->getName().getIdentifier(),
10614ecafd0SChia-hung Duan                              regionArgs.take_front(op->getNumOperands()),
10714ecafd0SChia-hung Duan                              resultTypes, op->getAttrs());
10853a0d45dSSean Silva           builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
10953a0d45dSSean Silva         });
11053a0d45dSSean Silva     return success();
11153a0d45dSSean Silva   }
11253a0d45dSSean Silva };
11353a0d45dSSean Silva } // namespace
11453a0d45dSSean Silva 
populateElementwiseToLinalgConversionPatterns(RewritePatternSet & patterns)115ea069aebSMaheshRavishankar void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
116dc4e913bSChris Lattner     RewritePatternSet &patterns) {
11776f3c2f3SRiver Riddle   patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
11876f3c2f3SRiver Riddle       patterns.getContext());
11953a0d45dSSean Silva }
12053a0d45dSSean Silva 
12153a0d45dSSean Silva namespace {
12253a0d45dSSean Silva class ConvertElementwiseToLinalgPass
1231e98d488SQuinn Dawkins     : public impl::ConvertElementwiseToLinalgPassBase<
12467d0d7acSMichele Scuttari           ConvertElementwiseToLinalgPass> {
1251e98d488SQuinn Dawkins   using impl::ConvertElementwiseToLinalgPassBase<
1261e98d488SQuinn Dawkins       ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
12753a0d45dSSean Silva 
runOnOperation()128c10995a8SStella Laurenzo   void runOnOperation() final {
12902b6fb21SMehdi Amini     auto *func = getOperation();
13053a0d45dSSean Silva     auto *context = &getContext();
13153a0d45dSSean Silva     ConversionTarget target(*context);
132dc4e913bSChris Lattner     RewritePatternSet patterns(context);
13353a0d45dSSean Silva 
134ea069aebSMaheshRavishankar     mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
13553a0d45dSSean Silva     target.markUnknownOpDynamicallyLegal([](Operation *op) {
13653a0d45dSSean Silva       return !isElementwiseMappableOpOnRankedTensors(op);
13753a0d45dSSean Silva     });
13853a0d45dSSean Silva 
13953a0d45dSSean Silva     if (failed(applyPartialConversion(func, target, std::move(patterns))))
14053a0d45dSSean Silva       signalPassFailure();
14153a0d45dSSean Silva   }
14253a0d45dSSean Silva };
14353a0d45dSSean Silva } // namespace
144