xref: /llvm-project/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
13fef2d26SRiver Riddle //===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file contains test passes for expanding math operations into
103fef2d26SRiver Riddle // polynomial approximations.
113fef2d26SRiver Riddle //
123fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
133fef2d26SRiver Riddle 
14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
153fef2d26SRiver Riddle #include "mlir/Dialect/Math/IR/Math.h"
163fef2d26SRiver Riddle #include "mlir/Dialect/Math/Transforms/Passes.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1835553d45SEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
193fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
203fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
213fef2d26SRiver Riddle 
223fef2d26SRiver Riddle using namespace mlir;
233fef2d26SRiver Riddle 
243fef2d26SRiver Riddle namespace {
253fef2d26SRiver Riddle struct TestMathPolynomialApproximationPass
2687d6bf37SRiver Riddle     : public PassWrapper<TestMathPolynomialApproximationPass, OperationPass<>> {
275e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
285e50dd04SRiver Riddle       TestMathPolynomialApproximationPass)
295e50dd04SRiver Riddle 
3035553d45SEmilio Cota   TestMathPolynomialApproximationPass() = default;
3135553d45SEmilio Cota   TestMathPolynomialApproximationPass(
323bab9d4eSMehdi Amini       const TestMathPolynomialApproximationPass &pass)
333bab9d4eSMehdi Amini       : PassWrapper(pass) {}
3435553d45SEmilio Cota 
3541574554SRiver Riddle   void runOnOperation() override;
363fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
37abc362a1SJakub Kuderski     registry.insert<arith::ArithDialect, math::MathDialect,
38a54f4eaeSMogball                     vector::VectorDialect>();
3935553d45SEmilio Cota     if (enableAvx2)
4035553d45SEmilio Cota       registry.insert<x86vector::X86VectorDialect>();
413fef2d26SRiver Riddle   }
42b5e22e6dSMehdi Amini   StringRef getArgument() const final {
43b5e22e6dSMehdi Amini     return "test-math-polynomial-approximation";
44b5e22e6dSMehdi Amini   }
45b5e22e6dSMehdi Amini   StringRef getDescription() const final {
46b5e22e6dSMehdi Amini     return "Test math polynomial approximations";
47b5e22e6dSMehdi Amini   }
4835553d45SEmilio Cota 
4935553d45SEmilio Cota   Option<bool> enableAvx2{
5035553d45SEmilio Cota       *this, "enable-avx2",
5135553d45SEmilio Cota       llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
5235553d45SEmilio Cota                      "X86Vector dialect"),
5335553d45SEmilio Cota       llvm::cl::init(false)};
543fef2d26SRiver Riddle };
55be0a7e9fSMehdi Amini } // namespace
563fef2d26SRiver Riddle 
5741574554SRiver Riddle void TestMathPolynomialApproximationPass::runOnOperation() {
583fef2d26SRiver Riddle   RewritePatternSet patterns(&getContext());
5902b6fb21SMehdi Amini   MathPolynomialApproximationOptions approxOptions;
6002b6fb21SMehdi Amini   approxOptions.enableAvx2 = enableAvx2;
6102b6fb21SMehdi Amini   populateMathPolynomialApproximationPatterns(patterns, approxOptions);
62*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
633fef2d26SRiver Riddle }
643fef2d26SRiver Riddle 
653fef2d26SRiver Riddle namespace mlir {
663fef2d26SRiver Riddle namespace test {
673fef2d26SRiver Riddle void registerTestMathPolynomialApproximationPass() {
68b5e22e6dSMehdi Amini   PassRegistration<TestMathPolynomialApproximationPass>();
693fef2d26SRiver Riddle }
703fef2d26SRiver Riddle } // namespace test
713fef2d26SRiver Riddle } // namespace mlir
72