xref: /llvm-project/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (revision f09db6a3af971ab7d9bbc7ba574a8dc0c10b2940)
1 //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes Tosa operations to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Index/IR/IndexDialect.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
23 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
24 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/PassManager.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "mlir/Transforms/Passes.h"
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_TOSATOLINALG
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 namespace {
39 struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
40 public:
41   void getDependentDialects(DialectRegistry &registry) const override {
42     registry
43         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
44                 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
45   }
46 
47   void runOnOperation() override {
48     RewritePatternSet patterns(&getContext());
49     ConversionTarget target(getContext());
50     target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
51                            scf::SCFDialect>();
52     target.addIllegalDialect<tosa::TosaDialect>();
53 
54     // Not every TOSA op can be legalized to linalg.
55     target.addLegalOp<tosa::ApplyScaleOp>();
56     target.addLegalOp<tosa::IfOp>();
57     target.addLegalOp<tosa::ConstOp>();
58     target.addLegalOp<tosa::ConstShapeOp>();
59     target.addLegalOp<tosa::WhileOp>();
60     target.addLegalOp<tosa::ConcatOp>();
61     target.addLegalOp<tosa::SliceOp>();
62     target.addLegalOp<tosa::ReshapeOp>();
63     target.addLegalOp<tosa::PadOp>();
64 
65     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
66 
67     TypeConverter converter;
68     tosa::populateTosaTypeConversion(converter);
69 
70     FunctionOpInterface func = getOperation();
71     mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
72     if (failed(applyFullConversion(func, target, std::move(patterns))))
73       signalPassFailure();
74   }
75 };
76 } // namespace
77 
78 std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
79   return std::make_unique<TosaToLinalg>();
80 }
81 
82 void mlir::tosa::addTosaToLinalgPasses(
83     OpPassManager &pm, const TosaToLinalgOptions &options,
84     const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
85     std::optional<tosa::TosaValidationOptions> validationOptions) {
86   // Optional decompositions are designed to benefit linalg.
87   if (!options.disableTosaDecompositions)
88     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
89   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
90 
91   pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
92   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
93   pm.addNestedPass<func::FuncOp>(
94       tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
95   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
96   // TODO: Remove pass that operates on const tensor and enable optionality
97   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
98       {options.aggressiveReduceConstant}));
99   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
100   if (validationOptions)
101     pm.addPass(tosa::createTosaValidation(*validationOptions));
102   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Pipeline registration.
107 //===----------------------------------------------------------------------===//
108 
109 void mlir::tosa::registerTosaToLinalgPipelines() {
110   PassPipelineRegistration<>(
111       "tosa-to-linalg-pipeline",
112       "The default pipeline for converting TOSA operators to the equivalent "
113       "operations using the tensor operations in LinAlg as well as LinAlg "
114       "named operations.",
115       [](OpPassManager &pm) {
116         TosaToLinalgOptions tosaToLinalgOptions;
117         TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
118         TosaValidationOptions validationOptions;
119         validationOptions.profile = {"none"};
120         validationOptions.StrictOperationSpecAlignment = true;
121         validationOptions.level = tosa::TosaLevelEnum::EightK;
122         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
123                                     tosaToLinalgNamedOptions,
124                                     validationOptions);
125       });
126 }
127