xref: /llvm-project/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp (revision 6e6352f4da760cfe7486d7340da0113b98cb797f)
1 //===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the TOSA dialect to the MLProgram dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
14 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace tosa;
21 namespace {
22 
23 class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
24 public:
25   using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
26 
matchAndRewrite(tosa::VariableOp op,PatternRewriter & rewriter) const27   LogicalResult matchAndRewrite(tosa::VariableOp op,
28                                 PatternRewriter &rewriter) const final {
29     auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30         op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
31         op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
32     newVariable.setPrivate();
33     rewriter.replaceOp(op, newVariable);
34     return success();
35   }
36 };
37 
38 class VariableWriteOpConverter
39     : public OpRewritePattern<tosa::VariableWriteOp> {
40 public:
41   using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
42 
matchAndRewrite(tosa::VariableWriteOp op,PatternRewriter & rewriter) const43   LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
44                                 PatternRewriter &rewriter) const final {
45     auto globalSymbolRef =
46         SymbolRefAttr::get(rewriter.getContext(), op.getName());
47     auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48         op.getLoc(), globalSymbolRef, op.getValue());
49     rewriter.replaceOp(op, newVariableWrite);
50     return success();
51   }
52 };
53 
54 class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
55 public:
56   using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
57 
matchAndRewrite(tosa::VariableReadOp op,PatternRewriter & rewriter) const58   LogicalResult matchAndRewrite(tosa::VariableReadOp op,
59                                 PatternRewriter &rewriter) const final {
60     auto globalSymbolRef =
61         SymbolRefAttr::get(rewriter.getContext(), op.getName());
62     auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
63         op.getLoc(), op.getType(), globalSymbolRef);
64     rewriter.replaceOp(op, newVariableRead);
65 
66     return success();
67   }
68 };
69 
70 } // namespace
71 
populateTosaToMLProgramConversionPatterns(RewritePatternSet * patterns)72 void mlir::tosa::populateTosaToMLProgramConversionPatterns(
73     RewritePatternSet *patterns) {
74   patterns->add<VariableOpConverter, VariableWriteOpConverter,
75                 VariableReadOpConverter>(patterns->getContext());
76 }
77