11d973b7dSRob Suderman //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// 21d973b7dSRob Suderman // 31d973b7dSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41d973b7dSRob Suderman // See https://llvm.org/LICENSE.txt for license information. 51d973b7dSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61d973b7dSRob Suderman // 71d973b7dSRob Suderman //===----------------------------------------------------------------------===// 81d973b7dSRob Suderman // 91d973b7dSRob Suderman // This transformation pass legalizes Tosa operations to the Linalg dialect. 101d973b7dSRob Suderman // 111d973b7dSRob Suderman //===----------------------------------------------------------------------===// 121d973b7dSRob Suderman 131d973b7dSRob Suderman #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" 1467d0d7acSMichele Scuttari 15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1667d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 17ee0284ecSDmitriy Smirnov #include "mlir/Dialect/Index/IR/IndexDialect.h" 18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 194348d8abSStephan Herhut #include "mlir/Dialect/Math/IR/Math.h" 208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 214157a079SRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h" 221d973b7dSRob Suderman #include "mlir/Dialect/Tosa/IR/TosaOps.h" 231d973b7dSRob Suderman #include "mlir/Dialect/Tosa/Transforms/Passes.h" 241d973b7dSRob Suderman #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 251d973b7dSRob Suderman #include "mlir/IR/PatternMatch.h" 261d973b7dSRob Suderman #include "mlir/Pass/PassManager.h" 271d973b7dSRob Suderman #include "mlir/Transforms/DialectConversion.h" 281d973b7dSRob Suderman #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29f0cb77d7SRob Suderman #include "mlir/Transforms/Passes.h" 301d973b7dSRob Suderman 3167d0d7acSMichele Scuttari namespace mlir { 3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSATOLINALG 3367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3467d0d7acSMichele Scuttari } // namespace mlir 3567d0d7acSMichele Scuttari 361d973b7dSRob Suderman using namespace mlir; 371d973b7dSRob Suderman 381d973b7dSRob Suderman namespace { 3967d0d7acSMichele Scuttari struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> { 401d973b7dSRob Suderman public: 411d973b7dSRob Suderman void getDependentDialects(DialectRegistry ®istry) const override { 421f971e23SRiver Riddle registry 43abc362a1SJakub Kuderski .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect, 44ee0284ecSDmitriy Smirnov index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>(); 451d973b7dSRob Suderman } 461d973b7dSRob Suderman 4741574554SRiver Riddle void runOnOperation() override { 48dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 491d973b7dSRob Suderman ConversionTarget target(getContext()); 501f971e23SRiver Riddle target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect, 511f971e23SRiver Riddle scf::SCFDialect>(); 521d973b7dSRob Suderman target.addIllegalDialect<tosa::TosaDialect>(); 53286a9d46SRob Suderman 54286a9d46SRob Suderman // Not every TOSA op can be legalized to linalg. 55286a9d46SRob Suderman target.addLegalOp<tosa::ApplyScaleOp>(); 56286a9d46SRob Suderman target.addLegalOp<tosa::IfOp>(); 57286a9d46SRob Suderman target.addLegalOp<tosa::ConstOp>(); 58*f09db6a3SJerry-Ge target.addLegalOp<tosa::ConstShapeOp>(); 59286a9d46SRob Suderman target.addLegalOp<tosa::WhileOp>(); 60fbf719b8SMaya Amrami target.addLegalOp<tosa::ConcatOp>(); 6154eec7caSRob Suderman target.addLegalOp<tosa::SliceOp>(); 62723979efSKrzysztof Drewniak target.addLegalOp<tosa::ReshapeOp>(); 632a196254SRamkumar Ramachandra target.addLegalOp<tosa::PadOp>(); 64286a9d46SRob Suderman 651d973b7dSRob Suderman target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 661d973b7dSRob Suderman 678d237190SMatthias Gehre TypeConverter converter; 688d237190SMatthias Gehre tosa::populateTosaTypeConversion(converter); 698d237190SMatthias Gehre 7047f175b0SRiver Riddle FunctionOpInterface func = getOperation(); 718d237190SMatthias Gehre mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns); 721d973b7dSRob Suderman if (failed(applyFullConversion(func, target, std::move(patterns)))) 731d973b7dSRob Suderman signalPassFailure(); 741d973b7dSRob Suderman } 751d973b7dSRob Suderman }; 761d973b7dSRob Suderman } // namespace 771d973b7dSRob Suderman 78039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() { 79039b969bSMichele Scuttari return std::make_unique<TosaToLinalg>(); 80039b969bSMichele Scuttari } 81039b969bSMichele Scuttari 8232b7c1ffSBenjamin Maxwell void mlir::tosa::addTosaToLinalgPasses( 839dd15f74SAmir Bishara OpPassManager &pm, const TosaToLinalgOptions &options, 84acc6f3e9Sbjacob const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions, 85ecce5ccdSMatthias Gehre std::optional<tosa::TosaValidationOptions> validationOptions) { 86173fce42SRob Suderman // Optional decompositions are designed to benefit linalg. 879dd15f74SAmir Bishara if (!options.disableTosaDecompositions) 88039b969bSMichele Scuttari pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions()); 8958ceae95SRiver Riddle pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); 90173fce42SRob Suderman 91309bfecfSAviad Cohen pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass()); 9258ceae95SRiver Riddle pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass()); 93acc6f3e9Sbjacob pm.addNestedPass<func::FuncOp>( 94acc6f3e9Sbjacob tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions)); 9558ceae95SRiver Riddle pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); 963bcaf2ebSGeorgios Pinitas // TODO: Remove pass that operates on const tensor and enable optionality 979dd15f74SAmir Bishara pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass( 989dd15f74SAmir Bishara {options.aggressiveReduceConstant})); 9958ceae95SRiver Riddle pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass()); 100ecce5ccdSMatthias Gehre if (validationOptions) 101ecce5ccdSMatthias Gehre pm.addPass(tosa::createTosaValidation(*validationOptions)); 102039b969bSMichele Scuttari pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg()); 1031d973b7dSRob Suderman } 104cfc922fcSTai Ly 105cfc922fcSTai Ly //===----------------------------------------------------------------------===// 106cfc922fcSTai Ly // Pipeline registration. 107cfc922fcSTai Ly //===----------------------------------------------------------------------===// 108cfc922fcSTai Ly 109cfc922fcSTai Ly void mlir::tosa::registerTosaToLinalgPipelines() { 110cfc922fcSTai Ly PassPipelineRegistration<>( 111cfc922fcSTai Ly "tosa-to-linalg-pipeline", 112cfc922fcSTai Ly "The default pipeline for converting TOSA operators to the equivalent " 113cfc922fcSTai Ly "operations using the tensor operations in LinAlg as well as LinAlg " 114cfc922fcSTai Ly "named operations.", 115cfc922fcSTai Ly [](OpPassManager &pm) { 116cfc922fcSTai Ly TosaToLinalgOptions tosaToLinalgOptions; 117acc6f3e9Sbjacob TosaToLinalgNamedOptions tosaToLinalgNamedOptions; 118ecce5ccdSMatthias Gehre TosaValidationOptions validationOptions; 119cc9e7cb9STatWai Chong validationOptions.profile = {"none"}; 120ecce5ccdSMatthias Gehre validationOptions.StrictOperationSpecAlignment = true; 121ecce5ccdSMatthias Gehre validationOptions.level = tosa::TosaLevelEnum::EightK; 122cfc922fcSTai Ly tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, 123acc6f3e9Sbjacob tosaToLinalgNamedOptions, 124ecce5ccdSMatthias Gehre validationOptions); 125cfc922fcSTai Ly }); 126cfc922fcSTai Ly } 127