1 //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" 10 11 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 17 namespace mlir { 18 namespace NVVM { 19 #define GEN_PASS_DEF_NVVMOPTIMIZEFORTARGET 20 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" 21 } // namespace NVVM 22 } // namespace mlir 23 24 using namespace mlir; 25 26 namespace { 27 // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one 28 // (conditional) Newton iteration. 29 // 30 // This as accurate as promoting the division to fp32 in the NVPTX backend, but 31 // faster because it performs less Newton iterations, avoids the slow path 32 // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions 33 // by the same divisor. 34 struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> { 35 using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern; 36 37 private: 38 LogicalResult matchAndRewrite(LLVM::FDivOp op, 39 PatternRewriter &rewriter) const override; 40 }; 41 42 struct NVVMOptimizeForTarget 43 : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> { 44 void runOnOperation() override; 45 46 void getDependentDialects(DialectRegistry ®istry) const override { 47 registry.insert<NVVM::NVVMDialect>(); 48 } 49 }; 50 } // namespace 51 52 LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, 53 PatternRewriter &rewriter) const { 54 if (!op.getType().isF16()) 55 return rewriter.notifyMatchFailure(op, "not f16"); 56 Location loc = op.getLoc(); 57 58 Type f32Type = rewriter.getF32Type(); 59 Type i32Type = rewriter.getI32Type(); 60 61 // Extend lhs and rhs to fp32. 62 Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs()); 63 Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs()); 64 65 // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. 66 Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs); 67 Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp); 68 69 // Refine the approximation with one Newton iteration: 70 // float refined = approx + (lhs - approx * rhs) * rcp; 71 Value err = rewriter.create<LLVM::FMAOp>( 72 loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs); 73 Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx); 74 75 // Use refined value if approx is normal (exponent neither all 0 or all 1). 76 Value mask = rewriter.create<LLVM::ConstantOp>( 77 loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); 78 Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx); 79 Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask); 80 Value zero = rewriter.create<LLVM::ConstantOp>( 81 loc, i32Type, rewriter.getUI32IntegerAttr(0)); 82 Value pred = rewriter.create<LLVM::OrOp>( 83 loc, 84 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero), 85 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask)); 86 Value result = 87 rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined); 88 89 // Replace with trucation back to fp16. 90 rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result); 91 92 return success(); 93 } 94 95 void NVVMOptimizeForTarget::runOnOperation() { 96 MLIRContext *ctx = getOperation()->getContext(); 97 RewritePatternSet patterns(ctx); 98 patterns.add<ExpandDivF16>(ctx); 99 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 100 return signalPassFailure(); 101 } 102 103 std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() { 104 return std::make_unique<NVVMOptimizeForTarget>(); 105 } 106