xref: /llvm-project/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (revision f09db6a3af971ab7d9bbc7ba574a8dc0c10b2940)
11d973b7dSRob Suderman //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
21d973b7dSRob Suderman //
31d973b7dSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41d973b7dSRob Suderman // See https://llvm.org/LICENSE.txt for license information.
51d973b7dSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61d973b7dSRob Suderman //
71d973b7dSRob Suderman //===----------------------------------------------------------------------===//
81d973b7dSRob Suderman //
91d973b7dSRob Suderman // This transformation pass legalizes Tosa operations to the Linalg dialect.
101d973b7dSRob Suderman //
111d973b7dSRob Suderman //===----------------------------------------------------------------------===//
121d973b7dSRob Suderman 
131d973b7dSRob Suderman #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
1467d0d7acSMichele Scuttari 
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1667d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
17ee0284ecSDmitriy Smirnov #include "mlir/Dialect/Index/IR/IndexDialect.h"
18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
194348d8abSStephan Herhut #include "mlir/Dialect/Math/IR/Math.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
214157a079SRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h"
221d973b7dSRob Suderman #include "mlir/Dialect/Tosa/IR/TosaOps.h"
231d973b7dSRob Suderman #include "mlir/Dialect/Tosa/Transforms/Passes.h"
241d973b7dSRob Suderman #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
251d973b7dSRob Suderman #include "mlir/IR/PatternMatch.h"
261d973b7dSRob Suderman #include "mlir/Pass/PassManager.h"
271d973b7dSRob Suderman #include "mlir/Transforms/DialectConversion.h"
281d973b7dSRob Suderman #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29f0cb77d7SRob Suderman #include "mlir/Transforms/Passes.h"
301d973b7dSRob Suderman 
3167d0d7acSMichele Scuttari namespace mlir {
3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSATOLINALG
3367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3467d0d7acSMichele Scuttari } // namespace mlir
3567d0d7acSMichele Scuttari 
361d973b7dSRob Suderman using namespace mlir;
371d973b7dSRob Suderman 
381d973b7dSRob Suderman namespace {
3967d0d7acSMichele Scuttari struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
401d973b7dSRob Suderman public:
411d973b7dSRob Suderman   void getDependentDialects(DialectRegistry &registry) const override {
421f971e23SRiver Riddle     registry
43abc362a1SJakub Kuderski         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
44ee0284ecSDmitriy Smirnov                 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
451d973b7dSRob Suderman   }
461d973b7dSRob Suderman 
4741574554SRiver Riddle   void runOnOperation() override {
48dc4e913bSChris Lattner     RewritePatternSet patterns(&getContext());
491d973b7dSRob Suderman     ConversionTarget target(getContext());
501f971e23SRiver Riddle     target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
511f971e23SRiver Riddle                            scf::SCFDialect>();
521d973b7dSRob Suderman     target.addIllegalDialect<tosa::TosaDialect>();
53286a9d46SRob Suderman 
54286a9d46SRob Suderman     // Not every TOSA op can be legalized to linalg.
55286a9d46SRob Suderman     target.addLegalOp<tosa::ApplyScaleOp>();
56286a9d46SRob Suderman     target.addLegalOp<tosa::IfOp>();
57286a9d46SRob Suderman     target.addLegalOp<tosa::ConstOp>();
58*f09db6a3SJerry-Ge     target.addLegalOp<tosa::ConstShapeOp>();
59286a9d46SRob Suderman     target.addLegalOp<tosa::WhileOp>();
60fbf719b8SMaya Amrami     target.addLegalOp<tosa::ConcatOp>();
6154eec7caSRob Suderman     target.addLegalOp<tosa::SliceOp>();
62723979efSKrzysztof Drewniak     target.addLegalOp<tosa::ReshapeOp>();
632a196254SRamkumar Ramachandra     target.addLegalOp<tosa::PadOp>();
64286a9d46SRob Suderman 
651d973b7dSRob Suderman     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
661d973b7dSRob Suderman 
678d237190SMatthias Gehre     TypeConverter converter;
688d237190SMatthias Gehre     tosa::populateTosaTypeConversion(converter);
698d237190SMatthias Gehre 
7047f175b0SRiver Riddle     FunctionOpInterface func = getOperation();
718d237190SMatthias Gehre     mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
721d973b7dSRob Suderman     if (failed(applyFullConversion(func, target, std::move(patterns))))
731d973b7dSRob Suderman       signalPassFailure();
741d973b7dSRob Suderman   }
751d973b7dSRob Suderman };
761d973b7dSRob Suderman } // namespace
771d973b7dSRob Suderman 
78039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
79039b969bSMichele Scuttari   return std::make_unique<TosaToLinalg>();
80039b969bSMichele Scuttari }
81039b969bSMichele Scuttari 
8232b7c1ffSBenjamin Maxwell void mlir::tosa::addTosaToLinalgPasses(
839dd15f74SAmir Bishara     OpPassManager &pm, const TosaToLinalgOptions &options,
84acc6f3e9Sbjacob     const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
85ecce5ccdSMatthias Gehre     std::optional<tosa::TosaValidationOptions> validationOptions) {
86173fce42SRob Suderman   // Optional decompositions are designed to benefit linalg.
879dd15f74SAmir Bishara   if (!options.disableTosaDecompositions)
88039b969bSMichele Scuttari     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
8958ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
90173fce42SRob Suderman 
91309bfecfSAviad Cohen   pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
9258ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
93acc6f3e9Sbjacob   pm.addNestedPass<func::FuncOp>(
94acc6f3e9Sbjacob       tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
9558ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
963bcaf2ebSGeorgios Pinitas   // TODO: Remove pass that operates on const tensor and enable optionality
979dd15f74SAmir Bishara   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
989dd15f74SAmir Bishara       {options.aggressiveReduceConstant}));
9958ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
100ecce5ccdSMatthias Gehre   if (validationOptions)
101ecce5ccdSMatthias Gehre     pm.addPass(tosa::createTosaValidation(*validationOptions));
102039b969bSMichele Scuttari   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
1031d973b7dSRob Suderman }
104cfc922fcSTai Ly 
105cfc922fcSTai Ly //===----------------------------------------------------------------------===//
106cfc922fcSTai Ly // Pipeline registration.
107cfc922fcSTai Ly //===----------------------------------------------------------------------===//
108cfc922fcSTai Ly 
109cfc922fcSTai Ly void mlir::tosa::registerTosaToLinalgPipelines() {
110cfc922fcSTai Ly   PassPipelineRegistration<>(
111cfc922fcSTai Ly       "tosa-to-linalg-pipeline",
112cfc922fcSTai Ly       "The default pipeline for converting TOSA operators to the equivalent "
113cfc922fcSTai Ly       "operations using the tensor operations in LinAlg as well as LinAlg "
114cfc922fcSTai Ly       "named operations.",
115cfc922fcSTai Ly       [](OpPassManager &pm) {
116cfc922fcSTai Ly         TosaToLinalgOptions tosaToLinalgOptions;
117acc6f3e9Sbjacob         TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
118ecce5ccdSMatthias Gehre         TosaValidationOptions validationOptions;
119cc9e7cb9STatWai Chong         validationOptions.profile = {"none"};
120ecce5ccdSMatthias Gehre         validationOptions.StrictOperationSpecAlignment = true;
121ecce5ccdSMatthias Gehre         validationOptions.level = tosa::TosaLevelEnum::EightK;
122cfc922fcSTai Ly         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
123acc6f3e9Sbjacob                                     tosaToLinalgNamedOptions,
124ecce5ccdSMatthias Gehre                                     validationOptions);
125cfc922fcSTai Ly       });
126cfc922fcSTai Ly }
127