xref: /llvm-project/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp (revision af22e274e9c5643780f25066442e05b5bd453328)
1 //===- TosaToTensorPass.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes Tosa operations to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
18 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/PassManager.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_TOSATOTENSOR
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace tosa;
31 
32 namespace {
33 struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
34 public:
runOnOperation__anon849f45e10111::TosaToTensor35   void runOnOperation() override {
36     RewritePatternSet patterns(&getContext());
37     ConversionTarget target(getContext());
38     target.addIllegalOp<tosa::ConcatOp>();
39     target.addIllegalOp<tosa::ReshapeOp>();
40     target.addIllegalOp<tosa::SliceOp>();
41     target.addIllegalOp<tosa::PadOp>();
42     target.addLegalDialect<arith::ArithDialect>();
43     target.addLegalDialect<tensor::TensorDialect>();
44 
45     TypeConverter converter;
46     mlir::tosa::populateTosaTypeConversion(converter);
47 
48     mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
49 
50     if (failed(applyPartialConversion(getOperation(), target,
51                                       std::move(patterns))))
52       signalPassFailure();
53   }
54 };
55 } // namespace
56 
createTosaToTensor()57 std::unique_ptr<Pass> mlir::tosa::createTosaToTensor() {
58   return std::make_unique<TosaToTensor>();
59 }
60