1 //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass legalizes Tosa operations to the Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Math/IR/Math.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 22 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 23 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 29 namespace mlir { 30 #define GEN_PASS_DEF_TOSATOLINALGNAMED 31 #include "mlir/Conversion/Passes.h.inc" 32 } // namespace mlir 33 34 using namespace mlir; 35 36 namespace { 37 struct TosaToLinalgNamed 38 : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> { 39 public: 40 TosaToLinalgNamed(const TosaToLinalgNamedOptions &options) 41 : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {} 42 43 void getDependentDialects(DialectRegistry ®istry) const override { 44 registry 45 .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect, 46 tensor::TensorDialect, scf::SCFDialect>(); 47 } 48 49 void runOnOperation() override { 50 TypeConverter converter; 51 tosa::populateTosaTypeConversion(converter); 52 53 RewritePatternSet patterns(&getContext()); 54 ConversionTarget target(getContext()); 55 target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect, 56 tensor::TensorDialect, scf::SCFDialect>(); 57 58 // Not every TOSA op can be legalized to linalg. 59 target.addIllegalOp<tosa::Conv2DOp>(); 60 target.addIllegalOp<tosa::Conv3DOp>(); 61 target.addIllegalOp<tosa::DepthwiseConv2DOp>(); 62 target.addIllegalOp<tosa::MaxPool2dOp>(); 63 target.addIllegalOp<tosa::AvgPool2dOp>(); 64 target.addIllegalOp<tosa::MatMulOp>(); 65 target.addIllegalOp<tosa::FullyConnectedOp>(); 66 target.addIllegalOp<tosa::TransposeOp>(); 67 68 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 69 70 FunctionOpInterface func = getOperation(); 71 TosaToLinalgNamedOptions options; 72 options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF; 73 tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns, 74 options); 75 if (failed(applyFullConversion(func, target, std::move(patterns)))) 76 signalPassFailure(); 77 } 78 }; 79 } // namespace 80 81 std::unique_ptr<Pass> 82 mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) { 83 return std::make_unique<TosaToLinalgNamed>(options); 84 } 85