13fef2d26SRiver Riddle //===- TestPolynomialApproximation.cpp - Test math ops approximations -----===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file contains test passes for expanding math operations into 103fef2d26SRiver Riddle // polynomial approximations. 113fef2d26SRiver Riddle // 123fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 133fef2d26SRiver Riddle 14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 153fef2d26SRiver Riddle #include "mlir/Dialect/Math/IR/Math.h" 163fef2d26SRiver Riddle #include "mlir/Dialect/Math/Transforms/Passes.h" 1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 1835553d45SEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 193fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 203fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 213fef2d26SRiver Riddle 223fef2d26SRiver Riddle using namespace mlir; 233fef2d26SRiver Riddle 243fef2d26SRiver Riddle namespace { 253fef2d26SRiver Riddle struct TestMathPolynomialApproximationPass 2687d6bf37SRiver Riddle : public PassWrapper<TestMathPolynomialApproximationPass, OperationPass<>> { 275e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 285e50dd04SRiver Riddle TestMathPolynomialApproximationPass) 295e50dd04SRiver Riddle 3035553d45SEmilio Cota TestMathPolynomialApproximationPass() = default; 3135553d45SEmilio Cota TestMathPolynomialApproximationPass( 323bab9d4eSMehdi Amini const TestMathPolynomialApproximationPass &pass) 333bab9d4eSMehdi Amini : PassWrapper(pass) {} 3435553d45SEmilio Cota 3541574554SRiver Riddle void runOnOperation() override; 363fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 37abc362a1SJakub Kuderski registry.insert<arith::ArithDialect, math::MathDialect, 38a54f4eaeSMogball vector::VectorDialect>(); 3935553d45SEmilio Cota if (enableAvx2) 4035553d45SEmilio Cota registry.insert<x86vector::X86VectorDialect>(); 413fef2d26SRiver Riddle } 42b5e22e6dSMehdi Amini StringRef getArgument() const final { 43b5e22e6dSMehdi Amini return "test-math-polynomial-approximation"; 44b5e22e6dSMehdi Amini } 45b5e22e6dSMehdi Amini StringRef getDescription() const final { 46b5e22e6dSMehdi Amini return "Test math polynomial approximations"; 47b5e22e6dSMehdi Amini } 4835553d45SEmilio Cota 4935553d45SEmilio Cota Option<bool> enableAvx2{ 5035553d45SEmilio Cota *this, "enable-avx2", 5135553d45SEmilio Cota llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the " 5235553d45SEmilio Cota "X86Vector dialect"), 5335553d45SEmilio Cota llvm::cl::init(false)}; 543fef2d26SRiver Riddle }; 55be0a7e9fSMehdi Amini } // namespace 563fef2d26SRiver Riddle 5741574554SRiver Riddle void TestMathPolynomialApproximationPass::runOnOperation() { 583fef2d26SRiver Riddle RewritePatternSet patterns(&getContext()); 5902b6fb21SMehdi Amini MathPolynomialApproximationOptions approxOptions; 6002b6fb21SMehdi Amini approxOptions.enableAvx2 = enableAvx2; 6102b6fb21SMehdi Amini populateMathPolynomialApproximationPatterns(patterns, approxOptions); 62*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 633fef2d26SRiver Riddle } 643fef2d26SRiver Riddle 653fef2d26SRiver Riddle namespace mlir { 663fef2d26SRiver Riddle namespace test { 673fef2d26SRiver Riddle void registerTestMathPolynomialApproximationPass() { 68b5e22e6dSMehdi Amini PassRegistration<TestMathPolynomialApproximationPass>(); 693fef2d26SRiver Riddle } 703fef2d26SRiver Riddle } // namespace test 713fef2d26SRiver Riddle } // namespace mlir 72