xref: /llvm-project/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (revision 5ce271ef74dd3325993c827f496e460ced41af11)
//===- TosaToLinalgNamedPass.cpp - Lowering Tosa to Linalg Dialect --------===//
2f0cb77d7SRob Suderman //
3f0cb77d7SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f0cb77d7SRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5f0cb77d7SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f0cb77d7SRob Suderman //
7f0cb77d7SRob Suderman //===----------------------------------------------------------------------===//
8f0cb77d7SRob Suderman //
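// This transformation pass legalizes the subset of Tosa operations that lower
// to Linalg named operations (convolutions, pooling, matmul, fully-connected,
// transpose) to the Linalg dialect. Other operations are left untouched.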
10f0cb77d7SRob Suderman //
11f0cb77d7SRob Suderman //===----------------------------------------------------------------------===//
12f0cb77d7SRob Suderman 
13f0cb77d7SRob Suderman #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
1467d0d7acSMichele Scuttari 
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1667d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
17f0cb77d7SRob Suderman #include "mlir/Dialect/Linalg/IR/Linalg.h"
18f0cb77d7SRob Suderman #include "mlir/Dialect/Math/IR/Math.h"
198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
20f0cb77d7SRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h"
21f0cb77d7SRob Suderman #include "mlir/Dialect/Tosa/IR/TosaOps.h"
22f0cb77d7SRob Suderman #include "mlir/Dialect/Tosa/Transforms/Passes.h"
23f0cb77d7SRob Suderman #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
24f0cb77d7SRob Suderman #include "mlir/IR/PatternMatch.h"
25f0cb77d7SRob Suderman #include "mlir/Pass/PassManager.h"
26f0cb77d7SRob Suderman #include "mlir/Transforms/DialectConversion.h"
27f0cb77d7SRob Suderman #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28f0cb77d7SRob Suderman 
2967d0d7acSMichele Scuttari namespace mlir {
3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSATOLINALGNAMED
3167d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3267d0d7acSMichele Scuttari } // namespace mlir
3367d0d7acSMichele Scuttari 
34f0cb77d7SRob Suderman using namespace mlir;
35f0cb77d7SRob Suderman 
36f0cb77d7SRob Suderman namespace {
3767d0d7acSMichele Scuttari struct TosaToLinalgNamed
3867d0d7acSMichele Scuttari     : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
39f0cb77d7SRob Suderman public:
40acc6f3e9Sbjacob   TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
41acc6f3e9Sbjacob       : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
42acc6f3e9Sbjacob 
43f0cb77d7SRob Suderman   void getDependentDialects(DialectRegistry &registry) const override {
441f971e23SRiver Riddle     registry
45abc362a1SJakub Kuderski         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
46abc362a1SJakub Kuderski                 tensor::TensorDialect, scf::SCFDialect>();
47f0cb77d7SRob Suderman   }
48f0cb77d7SRob Suderman 
4941574554SRiver Riddle   void runOnOperation() override {
50*5ce271efSMatthias Gehre     TypeConverter converter;
51*5ce271efSMatthias Gehre     tosa::populateTosaTypeConversion(converter);
52*5ce271efSMatthias Gehre 
53f0cb77d7SRob Suderman     RewritePatternSet patterns(&getContext());
54f0cb77d7SRob Suderman     ConversionTarget target(getContext());
551f971e23SRiver Riddle     target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
561f971e23SRiver Riddle                            tensor::TensorDialect, scf::SCFDialect>();
57f0cb77d7SRob Suderman 
58f0cb77d7SRob Suderman     // Not every TOSA op can be legalized to linalg.
59f0cb77d7SRob Suderman     target.addIllegalOp<tosa::Conv2DOp>();
607ce53e31SRob Suderman     target.addIllegalOp<tosa::Conv3DOp>();
61f0cb77d7SRob Suderman     target.addIllegalOp<tosa::DepthwiseConv2DOp>();
62f0cb77d7SRob Suderman     target.addIllegalOp<tosa::MaxPool2dOp>();
63f0cb77d7SRob Suderman     target.addIllegalOp<tosa::AvgPool2dOp>();
64f0cb77d7SRob Suderman     target.addIllegalOp<tosa::MatMulOp>();
65f0cb77d7SRob Suderman     target.addIllegalOp<tosa::FullyConnectedOp>();
6611ac97c6SFelix Schneider     target.addIllegalOp<tosa::TransposeOp>();
67f0cb77d7SRob Suderman 
68f0cb77d7SRob Suderman     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
69f0cb77d7SRob Suderman 
7047f175b0SRiver Riddle     FunctionOpInterface func = getOperation();
71acc6f3e9Sbjacob     TosaToLinalgNamedOptions options;
72acc6f3e9Sbjacob     options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
73*5ce271efSMatthias Gehre     tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
74*5ce271efSMatthias Gehre                                                       options);
75f0cb77d7SRob Suderman     if (failed(applyFullConversion(func, target, std::move(patterns))))
76f0cb77d7SRob Suderman       signalPassFailure();
77f0cb77d7SRob Suderman   }
78f0cb77d7SRob Suderman };
79f0cb77d7SRob Suderman } // namespace
80039b969bSMichele Scuttari 
81acc6f3e9Sbjacob std::unique_ptr<Pass>
82acc6f3e9Sbjacob mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
83acc6f3e9Sbjacob   return std::make_unique<TosaToLinalgNamed>(options);
84039b969bSMichele Scuttari }
85