xref: /llvm-project/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1d9803841SGil Rapaport //===- FormExpressions.cpp - Form C-style expressions --------*- C++ -*-===//
2d9803841SGil Rapaport //
3d9803841SGil Rapaport // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d9803841SGil Rapaport // See https://llvm.org/LICENSE.txt for license information.
5d9803841SGil Rapaport // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d9803841SGil Rapaport //
7d9803841SGil Rapaport //===----------------------------------------------------------------------===//
8d9803841SGil Rapaport //
9d9803841SGil Rapaport // This file implements a pass that forms EmitC operations modeling C operators
10d9803841SGil Rapaport // into C-style expressions using the emitc.expression op.
11d9803841SGil Rapaport //
12d9803841SGil Rapaport //===----------------------------------------------------------------------===//
13d9803841SGil Rapaport 
14d9803841SGil Rapaport #include "mlir/Dialect/EmitC/IR/EmitC.h"
15d9803841SGil Rapaport #include "mlir/Dialect/EmitC/Transforms/Passes.h"
16d9803841SGil Rapaport #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
17d9803841SGil Rapaport #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18d9803841SGil Rapaport 
19d9803841SGil Rapaport namespace mlir {
20d9803841SGil Rapaport namespace emitc {
21d9803841SGil Rapaport #define GEN_PASS_DEF_FORMEXPRESSIONS
22d9803841SGil Rapaport #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
23d9803841SGil Rapaport } // namespace emitc
24d9803841SGil Rapaport } // namespace mlir
25d9803841SGil Rapaport 
26d9803841SGil Rapaport using namespace mlir;
27d9803841SGil Rapaport using namespace emitc;
28d9803841SGil Rapaport 
29d9803841SGil Rapaport namespace {
30d9803841SGil Rapaport struct FormExpressionsPass
31d9803841SGil Rapaport     : public emitc::impl::FormExpressionsBase<FormExpressionsPass> {
32d9803841SGil Rapaport   void runOnOperation() override {
33d9803841SGil Rapaport     Operation *rootOp = getOperation();
34d9803841SGil Rapaport     MLIRContext *context = rootOp->getContext();
35d9803841SGil Rapaport 
36d9803841SGil Rapaport     // Wrap each C operator op with an expression op.
37d9803841SGil Rapaport     OpBuilder builder(context);
38d9803841SGil Rapaport     auto matchFun = [&](Operation *op) {
395344a370SKirill Chibisov       if (op->hasTrait<OpTrait::emitc::CExpression>() &&
4009adb531SChris           !op->getParentOfType<emitc::ExpressionOp>() &&
4109adb531SChris           op->getNumResults() == 1)
42d9803841SGil Rapaport         createExpression(op, builder);
43d9803841SGil Rapaport     };
44d9803841SGil Rapaport     rootOp->walk(matchFun);
45d9803841SGil Rapaport 
46d9803841SGil Rapaport     // Fold expressions where possible.
47d9803841SGil Rapaport     RewritePatternSet patterns(context);
48d9803841SGil Rapaport     populateExpressionPatterns(patterns);
49d9803841SGil Rapaport 
50*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
51d9803841SGil Rapaport       return signalPassFailure();
52d9803841SGil Rapaport   }
53d9803841SGil Rapaport 
54d9803841SGil Rapaport   void getDependentDialects(DialectRegistry &registry) const override {
55d9803841SGil Rapaport     registry.insert<emitc::EmitCDialect>();
56d9803841SGil Rapaport   }
57d9803841SGil Rapaport };
58d9803841SGil Rapaport } // namespace
59d9803841SGil Rapaport 
60d9803841SGil Rapaport std::unique_ptr<Pass> mlir::emitc::createFormExpressionsPass() {
61d9803841SGil Rapaport   return std::make_unique<FormExpressionsPass>();
62d9803841SGil Rapaport }
63