xref: /llvm-project/flang/lib/Optimizer/Transforms/AnnotateConstant.cpp (revision 67d0d7ac0acb0665d6a09f61278fbcf51f0114c2)
1 //===-- AnnotateConstant.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Transforms/Passes.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 
15 namespace fir {
16 #define GEN_PASS_DEF_ANNOTATECONSTANTOPERANDS
17 #include "flang/Optimizer/Transforms/Passes.h.inc"
18 } // namespace fir
19 
20 #define DEBUG_TYPE "flang-annotate-constant"
21 
22 using namespace fir;
23 
24 namespace {
25 struct AnnotateConstantOperands
26     : public impl::AnnotateConstantOperandsBase<AnnotateConstantOperands> {
runOnOperation__anond97b18640111::AnnotateConstantOperands27   void runOnOperation() override {
28     auto *context = &getContext();
29     mlir::Dialect *firDialect = context->getLoadedDialect("fir");
30     getOperation()->walk([&](mlir::Operation *op) {
31       // We filter out other dialects even though they may undergo merging of
32       // non-equal constant values by the canonicalizer as well.
33       if (op->getDialect() == firDialect) {
34         llvm::SmallVector<mlir::Attribute> attrs;
35         bool hasOneOrMoreConstOpnd = false;
36         for (mlir::Value opnd : op->getOperands()) {
37           if (auto constOp = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
38                   opnd.getDefiningOp())) {
39             attrs.push_back(constOp.getValue());
40             hasOneOrMoreConstOpnd = true;
41           } else if (auto addrOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
42                          opnd.getDefiningOp())) {
43             attrs.push_back(addrOp.getSymbol());
44             hasOneOrMoreConstOpnd = true;
45           } else {
46             attrs.push_back(mlir::UnitAttr::get(context));
47           }
48         }
49         if (hasOneOrMoreConstOpnd)
50           op->setAttr("canonicalize_constant_operands",
51                       mlir::ArrayAttr::get(context, attrs));
52       }
53     });
54   }
55 };
56 
57 } // namespace
58 
createAnnotateConstantOperandsPass()59 std::unique_ptr<mlir::Pass> fir::createAnnotateConstantOperandsPass() {
60   return std::make_unique<AnnotateConstantOperands>();
61 }
62