xref: /llvm-project/mlir/test/lib/Dialect/Math/TestExpandMath.cpp (revision 45d83ae7df65a3c9843270d970119bc97957d830)
1 //===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains test passes for expanding math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Math/Transforms/Passes.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 struct TestExpandMathPass
24     : public PassWrapper<TestExpandMathPass, OperationPass<>> {
25   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
26 
27   void runOnOperation() override;
28   StringRef getArgument() const final { return "test-expand-math"; }
29   void getDependentDialects(DialectRegistry &registry) const override {
30     registry
31         .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
32   }
33   StringRef getDescription() const final { return "Test expanding math"; }
34 };
35 } // namespace
36 
37 void TestExpandMathPass::runOnOperation() {
38   RewritePatternSet patterns(&getContext());
39   populateExpandCtlzPattern(patterns);
40   populateExpandExp2FPattern(patterns);
41   populateExpandTanPattern(patterns);
42   populateExpandSinhPattern(patterns);
43   populateExpandCoshPattern(patterns);
44   populateExpandTanhPattern(patterns);
45   populateExpandAsinhPattern(patterns);
46   populateExpandAcoshPattern(patterns);
47   populateExpandAtanhPattern(patterns);
48   populateExpandFmaFPattern(patterns);
49   populateExpandCeilFPattern(patterns);
50   populateExpandPowFPattern(patterns);
51   populateExpandFPowIPattern(patterns);
52   populateExpandRoundFPattern(patterns);
53   populateExpandRoundEvenPattern(patterns);
54   populateExpandRsqrtPattern(patterns);
55   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
56 }
57 
58 namespace mlir {
59 namespace test {
60 void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
61 } // namespace test
62 } // namespace mlir
63