153a0d45dSSean Silva //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
253a0d45dSSean Silva //
353a0d45dSSean Silva // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
453a0d45dSSean Silva // See https://llvm.org/LICENSE.txt for license information.
553a0d45dSSean Silva // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
653a0d45dSSean Silva //
753a0d45dSSean Silva //===----------------------------------------------------------------------===//
853a0d45dSSean Silva
953a0d45dSSean Silva #include "mlir/Dialect/Linalg/Passes.h"
1053a0d45dSSean Silva
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
12b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
13ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14f75f391fSRob Suderman #include "mlir/Dialect/Linalg/Utils/Utils.h"
1553a0d45dSSean Silva #include "mlir/Transforms/DialectConversion.h"
1653a0d45dSSean Silva
1767d0d7acSMichele Scuttari namespace mlir {
181e98d488SQuinn Dawkins #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
1967d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc"
2067d0d7acSMichele Scuttari } // namespace mlir
2167d0d7acSMichele Scuttari
2253a0d45dSSean Silva using namespace mlir;
2353a0d45dSSean Silva
isElementwiseMappableOpOnRankedTensors(Operation * op)2453a0d45dSSean Silva static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
25bcc9b371SFrederik Gossen if (!OpTrait::hasElementwiseMappableTraits(op))
2653a0d45dSSean Silva return false;
2753a0d45dSSean Silva
2853a0d45dSSean Silva // TODO: The conversion pattern can be made to work for `any_of` here, but
2953a0d45dSSean Silva // it's more complex as it requires tracking which operands are scalars.
30*971b8525SJakub Kuderski return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
3153a0d45dSSean Silva }
3253a0d45dSSean Silva
33b7ae1d3dSnicolasvasilache /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
34b7ae1d3dSnicolasvasilache /// the result types and return a list of values such that, for each result type
35b7ae1d3dSnicolasvasilache /// `t` and value `v` at the same index `idx`:
36b7ae1d3dSnicolasvasilache /// 1. `v.getType() == t`
37b7ae1d3dSnicolasvasilache /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
38b7ae1d3dSnicolasvasilache /// such operand. Then`v == operand_first`.
3981ca5aa4SMatthias Springer /// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
40b7ae1d3dSnicolasvasilache /// a. Static and dynamic dims extracted from the first operand of `op`.
41b7ae1d3dSnicolasvasilache /// b. Elemental type equal to the elemental type of `t`.
42b7ae1d3dSnicolasvasilache ///
43b7ae1d3dSnicolasvasilache /// This is sufficient because ElementwiseMappable guarantees that "The static
44b7ae1d3dSnicolasvasilache /// types of all vector (resp. tensor) operands and results must have the same
45b7ae1d3dSnicolasvasilache /// shape".
46b7ae1d3dSnicolasvasilache static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder & b,Operation * op)47b7ae1d3dSnicolasvasilache getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
48b7ae1d3dSnicolasvasilache assert(isElementwiseMappableOpOnRankedTensors(op));
49b7ae1d3dSnicolasvasilache Location loc = op->getLoc();
50b7ae1d3dSnicolasvasilache ValueRange operands = op->getOperands();
51b7ae1d3dSnicolasvasilache TypeRange rankedTensorTypes = op->getResultTypes();
52b7ae1d3dSnicolasvasilache SmallVector<Value, 4> res;
53b7ae1d3dSnicolasvasilache res.reserve(rankedTensorTypes.size());
54b7ae1d3dSnicolasvasilache for (Type t : rankedTensorTypes) {
55b7ae1d3dSnicolasvasilache // Try to find an operand with type matching the result tensor.
56b7ae1d3dSnicolasvasilache bool found = false;
57b7ae1d3dSnicolasvasilache for (Value v : operands) {
58b7ae1d3dSnicolasvasilache if (v.getType() == t) {
59b7ae1d3dSnicolasvasilache found = true;
60b7ae1d3dSnicolasvasilache res.push_back(v);
61b7ae1d3dSnicolasvasilache break;
62b7ae1d3dSnicolasvasilache }
63b7ae1d3dSnicolasvasilache }
64b7ae1d3dSnicolasvasilache if (found)
65b7ae1d3dSnicolasvasilache continue;
66b7ae1d3dSnicolasvasilache
67b7ae1d3dSnicolasvasilache // Extract static / dynamic shape mix from the first operand.
6881ca5aa4SMatthias Springer res.push_back(b.create<tensor::EmptyOp>(
69be6d96e9SMatthias Springer loc, tensor::getMixedSizes(b, loc, operands.front()),
70be6d96e9SMatthias Springer cast<RankedTensorType>(t).getElementType()));
71b7ae1d3dSnicolasvasilache }
72b7ae1d3dSnicolasvasilache return res;
73b7ae1d3dSnicolasvasilache }
74b7ae1d3dSnicolasvasilache
7553a0d45dSSean Silva namespace {
765488a6b0SSean Silva struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
ConvertAnyElementwiseMappableOpOnRankedTensors__anonab5fbb810111::ConvertAnyElementwiseMappableOpOnRankedTensors7776f3c2f3SRiver Riddle ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
7876f3c2f3SRiver Riddle : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
matchAndRewrite__anonab5fbb810111::ConvertAnyElementwiseMappableOpOnRankedTensors7953a0d45dSSean Silva LogicalResult matchAndRewrite(Operation *op,
8053a0d45dSSean Silva PatternRewriter &rewriter) const final {
8153a0d45dSSean Silva if (!isElementwiseMappableOpOnRankedTensors(op))
8253a0d45dSSean Silva return rewriter.notifyMatchFailure(
8353a0d45dSSean Silva op, "requires elementwise op on ranked tensors");
8453a0d45dSSean Silva
855550c821STres Popp auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
8653a0d45dSSean Silva SmallVector<AffineMap, 3> indexingMaps(
8753a0d45dSSean Silva op->getNumResults() + op->getNumOperands(),
8853a0d45dSSean Silva rewriter.getMultiDimIdentityMap(rank));
89e6598b05SOleg Shyshkov SmallVector<utils::IteratorType, 6> iteratorTypes(
90e6598b05SOleg Shyshkov rank, utils::IteratorType::parallel);
91b7ae1d3dSnicolasvasilache auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
9253a0d45dSSean Silva rewriter.replaceOpWithNewOp<linalg::GenericOp>(
9353a0d45dSSean Silva op, /*resultTensorTypes=*/op->getResultTypes(),
9453a0d45dSSean Silva /*inputs=*/op->getOperands(),
95b7ae1d3dSnicolasvasilache /*outputs=*/outputs,
9653a0d45dSSean Silva /*indexingMaps=*/indexingMaps,
9753a0d45dSSean Silva /*iteratorTypes=*/iteratorTypes,
9853a0d45dSSean Silva /*bodyBuilder=*/
9953a0d45dSSean Silva [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
10053a0d45dSSean Silva auto resultTypes = llvm::to_vector<6>(
10153a0d45dSSean Silva llvm::map_range(op->getResultTypes(), [](Type type) {
1025550c821STres Popp return cast<TensorType>(type).getElementType();
10353a0d45dSSean Silva }));
10414ecafd0SChia-hung Duan auto *scalarOp =
10514ecafd0SChia-hung Duan builder.create(loc, op->getName().getIdentifier(),
10614ecafd0SChia-hung Duan regionArgs.take_front(op->getNumOperands()),
10714ecafd0SChia-hung Duan resultTypes, op->getAttrs());
10853a0d45dSSean Silva builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
10953a0d45dSSean Silva });
11053a0d45dSSean Silva return success();
11153a0d45dSSean Silva }
11253a0d45dSSean Silva };
11353a0d45dSSean Silva } // namespace
11453a0d45dSSean Silva
populateElementwiseToLinalgConversionPatterns(RewritePatternSet & patterns)115ea069aebSMaheshRavishankar void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
116dc4e913bSChris Lattner RewritePatternSet &patterns) {
11776f3c2f3SRiver Riddle patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
11876f3c2f3SRiver Riddle patterns.getContext());
11953a0d45dSSean Silva }
12053a0d45dSSean Silva
12153a0d45dSSean Silva namespace {
12253a0d45dSSean Silva class ConvertElementwiseToLinalgPass
1231e98d488SQuinn Dawkins : public impl::ConvertElementwiseToLinalgPassBase<
12467d0d7acSMichele Scuttari ConvertElementwiseToLinalgPass> {
1251e98d488SQuinn Dawkins using impl::ConvertElementwiseToLinalgPassBase<
1261e98d488SQuinn Dawkins ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
12753a0d45dSSean Silva
runOnOperation()128c10995a8SStella Laurenzo void runOnOperation() final {
12902b6fb21SMehdi Amini auto *func = getOperation();
13053a0d45dSSean Silva auto *context = &getContext();
13153a0d45dSSean Silva ConversionTarget target(*context);
132dc4e913bSChris Lattner RewritePatternSet patterns(context);
13353a0d45dSSean Silva
134ea069aebSMaheshRavishankar mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
13553a0d45dSSean Silva target.markUnknownOpDynamicallyLegal([](Operation *op) {
13653a0d45dSSean Silva return !isElementwiseMappableOpOnRankedTensors(op);
13753a0d45dSSean Silva });
13853a0d45dSSean Silva
13953a0d45dSSean Silva if (failed(applyPartialConversion(func, target, std::move(patterns))))
14053a0d45dSSean Silva signalPassFailure();
14153a0d45dSSean Silva }
14253a0d45dSSean Silva };
14353a0d45dSSean Silva } // namespace
144