//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements expansion of tanh op. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; /// Create a float constant. static Value createFloatConst(Location loc, Type type, double value, OpBuilder &b) { auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { return b.create(loc, DenseElementsAttr::get(shapedTy, attr)); } return b.create(loc, attr); } /// Create a float constant. 
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b) { auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { return b.create(loc, DenseElementsAttr::get(shapedTy, attr)); } return b.create(loc, attr); } static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { Type opType = operand.getType(); Value fixedConvert = b.create(b.getI64Type(), operand); Value fpFixedConvert = b.create(opType, fixedConvert); return fpFixedConvert; } /// Expands tanh op into /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { auto floatType = op.getOperand().getType(); Location loc = op.getLoc(); Value one = createFloatConst(loc, floatType, 1.0, rewriter); Value two = createFloatConst(loc, floatType, 2.0, rewriter); Value doubledX = rewriter.create(loc, op.getOperand(), two); // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} Value negDoubledX = rewriter.create(loc, doubledX); Value exp2x = rewriter.create(loc, negDoubledX); Value dividend = rewriter.create(loc, one, exp2x); Value divisor = rewriter.create(loc, one, exp2x); Value positiveRes = rewriter.create(loc, dividend, divisor); // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 exp2x = rewriter.create(loc, doubledX); dividend = rewriter.create(loc, exp2x, one); divisor = rewriter.create(loc, exp2x, one); Value negativeRes = rewriter.create(loc, dividend, divisor); // tanh(x) = x >= 0 ? positiveRes : negativeRes Value zero = createFloatConst(loc, floatType, 0.0, rewriter); Value cmpRes = rewriter.create(loc, arith::CmpFPredicate::OGE, op.getOperand(), zero); rewriter.replaceOpWithNewOp(op, cmpRes, positiveRes, negativeRes); return success(); } // Converts math.tan to math.sin, math.cos, and arith.divf. 
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Value operand = op.getOperand();
  Type type = operand.getType();
  Value sin = b.create<math::SinOp>(type, operand);
  Value cos = b.create<math::CosOp>(type, operand);
  Value div = b.create<arith::DivFOp>(type, sin, cos);
  rewriter.replaceOp(op, div);
  return success();
}

// Converts math.fma(a, b, c) to arith.mulf(a, b) followed by arith.addf(_, c).
// NOTE: the result is rounded twice, unlike a fused multiply-add.
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Value operandA = op.getOperand(0);
  Value operandB = op.getOperand(1);
  Value operandC = op.getOperand(2);
  Type type = op.getType();
  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
  Value add = b.create<arith::AddFOp>(type, mult, operandC);
  rewriter.replaceOp(op, add);
  return success();
}

// Converts a floorf() function to the following:
// floorf(float x) ->
//     y = trunc(x)
//     floorf(x) = (x < y) ? y - 1 : y   <= replaces the floorf op.
static LogicalResult convertFloorOp(math::FloorOp op,
                                    PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Value operand = op.getOperand();
  Type opType = operand.getType();
  Value fpFixedConvert = createTruncatedFPValue(operand, b);

  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);

  // Compare against the truncated value rather than against zero. trunc
  // rounds toward zero, so x < trunc(x) exactly when x is negative with a
  // fractional part. Negative integral inputs such as -2.0 are already their
  // own floor and must not be decremented.
  Value ltCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand,
                                          fpFixedConvert);
  Value decremented = b.create<arith::SubFOp>(opType, fpFixedConvert, one);
  // Selecting (instead of adding a 0.0 increment) keeps a -0.0 truncation
  // result intact.
  Value ret = b.create<arith::SelectOp>(ltCheck, decremented, fpFixedConvert);
  rewriter.replaceOp(op, ret);
  return success();
}

// Converts a ceilf() function to the following:
// ceilf(float x) ->
//     y = trunc(x)
//     ceilf(x) = (x > y) ? y + 1 : y   <= replaces the ceilf op.
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); Value fpFixedConvert = createTruncatedFPValue(operand, b); // Creating constants for later use. Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); Value gtCheck = b.create(arith::CmpFPredicate::OGT, operand, fpFixedConvert); Value incrValue = b.create(op->getLoc(), gtCheck, one, zero); Value ret = b.create(opType, fpFixedConvert, incrValue); rewriter.replaceOp(op, ret); return success(); } // Converts math.ctlz to scf and arith operations. This is done // by performing a binary search on the bits. static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter) { auto operand = op.getOperand(); auto operandTy = operand.getType(); auto eTy = getElementTypeOrSelf(operandTy); Location loc = op.getLoc(); int32_t bitwidth = eTy.getIntOrFloatBitWidth(); if (bitwidth > 64) return failure(); uint64_t allbits = -1; if (bitwidth < 64) { allbits = allbits >> (64 - bitwidth); } Value x = operand; Value count = createIntConst(loc, operandTy, 0, rewriter); for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { auto half = bw / 2; auto bits = createIntConst(loc, operandTy, half, rewriter); auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); Value pred = rewriter.create(loc, arith::CmpIPredicate::ule, x, mask); Value add = rewriter.create(loc, count, bits); Value shift = rewriter.create(loc, x, bits); x = rewriter.create(loc, pred, shift, x); count = rewriter.create(loc, pred, add, count); } Value zero = createIntConst(loc, operandTy, 0, rewriter); Value pred = rewriter.create(loc, arith::CmpIPredicate::eq, operand, zero); Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); Value sel = rewriter.create(loc, pred, bwval, count); rewriter.replaceOp(op, 
sel); return success(); } void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { patterns.add(convertCtlzOp); } void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { patterns.add(convertTanOp); } void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { patterns.add(convertTanhOp); } void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { patterns.add(convertFmaFOp); } void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { patterns.add(convertCeilOp); } void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { patterns.add(convertFloorOp); }