1 //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass legalizes Tosa operations to the Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/Index/IR/IndexDialect.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Math/IR/Math.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 23 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 24 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/Pass/PassManager.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29 #include "mlir/Transforms/Passes.h" 30 31 namespace mlir { 32 #define GEN_PASS_DEF_TOSATOLINALG 33 #include "mlir/Conversion/Passes.h.inc" 34 } // namespace mlir 35 36 using namespace mlir; 37 38 namespace { 39 struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> { 40 public: 41 void getDependentDialects(DialectRegistry ®istry) const override { 42 registry 43 .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect, 44 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>(); 45 } 46 47 void runOnOperation() override { 48 RewritePatternSet patterns(&getContext()); 49 ConversionTarget target(getContext()); 50 target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect, 51 scf::SCFDialect>(); 52 target.addIllegalDialect<tosa::TosaDialect>(); 53 54 // Not every TOSA op can be legalized to linalg. 55 target.addLegalOp<tosa::ApplyScaleOp>(); 56 target.addLegalOp<tosa::IfOp>(); 57 target.addLegalOp<tosa::ConstOp>(); 58 target.addLegalOp<tosa::ConstShapeOp>(); 59 target.addLegalOp<tosa::WhileOp>(); 60 target.addLegalOp<tosa::ConcatOp>(); 61 target.addLegalOp<tosa::SliceOp>(); 62 target.addLegalOp<tosa::ReshapeOp>(); 63 target.addLegalOp<tosa::PadOp>(); 64 65 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 66 67 TypeConverter converter; 68 tosa::populateTosaTypeConversion(converter); 69 70 FunctionOpInterface func = getOperation(); 71 mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns); 72 if (failed(applyFullConversion(func, target, std::move(patterns)))) 73 signalPassFailure(); 74 } 75 }; 76 } // namespace 77 78 std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() { 79 return std::make_unique<TosaToLinalg>(); 80 } 81 82 void mlir::tosa::addTosaToLinalgPasses( 83 OpPassManager &pm, const TosaToLinalgOptions &options, 84 const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions, 85 std::optional<tosa::TosaValidationOptions> validationOptions) { 86 // Optional decompositions are designed to benefit linalg. 87 if (!options.disableTosaDecompositions) 88 pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions()); 89 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); 90 91 pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass()); 92 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass()); 93 pm.addNestedPass<func::FuncOp>( 94 tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions)); 95 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); 96 // TODO: Remove pass that operates on const tensor and enable optionality 97 pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass( 98 {options.aggressiveReduceConstant})); 99 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass()); 100 if (validationOptions) 101 pm.addPass(tosa::createTosaValidation(*validationOptions)); 102 pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg()); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Pipeline registration. 107 //===----------------------------------------------------------------------===// 108 109 void mlir::tosa::registerTosaToLinalgPipelines() { 110 PassPipelineRegistration<>( 111 "tosa-to-linalg-pipeline", 112 "The default pipeline for converting TOSA operators to the equivalent " 113 "operations using the tensor operations in LinAlg as well as LinAlg " 114 "named operations.", 115 [](OpPassManager &pm) { 116 TosaToLinalgOptions tosaToLinalgOptions; 117 TosaToLinalgNamedOptions tosaToLinalgNamedOptions; 118 TosaValidationOptions validationOptions; 119 validationOptions.profile = {"none"}; 120 validationOptions.StrictOperationSpecAlignment = true; 121 validationOptions.level = tosa::TosaLevelEnum::EightK; 122 tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, 123 tosaToLinalgNamedOptions, 124 validationOptions); 125 }); 126 } 127