1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This pass promotes small floats (of some unsupported types T) to a supported 9 // type U by wrapping all float operations on Ts with expansion to and 10 // truncation from U, then operating on U. 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/Transforms/Passes.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Utils/Utils.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Location.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/ADT/STLExtras.h" 23 #include "llvm/Support/ErrorHandling.h" 24 #include <optional> 25 26 namespace mlir::arith { 27 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS 28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" 29 } // namespace mlir::arith 30 31 using namespace mlir; 32 33 namespace { 34 struct EmulateUnsupportedFloatsPass 35 : arith::impl::ArithEmulateUnsupportedFloatsBase< 36 EmulateUnsupportedFloatsPass> { 37 using arith::impl::ArithEmulateUnsupportedFloatsBase< 38 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase; 39 40 void runOnOperation() override; 41 }; 42 43 struct EmulateFloatPattern final : ConversionPattern { 44 EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx) 45 : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} 46 47 LogicalResult match(Operation *op) const override; 48 void rewrite(Operation *op, ArrayRef<Value> operands, 49 ConversionPatternRewriter &rewriter) const override; 50 }; 51 } // end namespace 52 53 LogicalResult EmulateFloatPattern::match(Operation *op) const { 54 if (getTypeConverter()->isLegal(op)) 55 return failure(); 56 // The rewrite doesn't handle cloning regions. 57 if (op->getNumRegions() != 0) 58 return failure(); 59 return success(); 60 } 61 62 void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands, 63 ConversionPatternRewriter &rewriter) const { 64 Location loc = op->getLoc(); 65 const TypeConverter *converter = getTypeConverter(); 66 SmallVector<Type> resultTypes; 67 if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) { 68 // Note to anyone looking for this error message: this is a "can't happen". 69 // If you're seeing it, there's a bug. 70 op->emitOpError("type conversion failed in float emulation"); 71 return; 72 } 73 Operation *expandedOp = 74 rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, 75 op->getAttrs(), op->getSuccessors(), /*regions=*/{}); 76 SmallVector<Value> newResults(expandedOp->getResults()); 77 for (auto [res, oldType, newType] : llvm::zip_equal( 78 MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { 79 if (oldType != newType) { 80 auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res); 81 truncFOp.setFastmath(arith::FastMathFlags::contract); 82 res = truncFOp.getResult(); 83 } 84 } 85 rewriter.replaceOp(op, newResults); 86 } 87 88 void mlir::arith::populateEmulateUnsupportedFloatsConversions( 89 TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) { 90 converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes), 91 targetType](Type type) -> std::optional<Type> { 92 if (llvm::is_contained(sourceTypes, type)) 93 return targetType; 94 if (auto shaped = dyn_cast<ShapedType>(type)) 95 if (llvm::is_contained(sourceTypes, shaped.getElementType())) 96 return shaped.clone(targetType); 97 // All other types legal 98 return type; 99 }); 100 converter.addTargetMaterialization( 101 [](OpBuilder &b, Type target, ValueRange input, Location loc) { 102 auto extFOp = b.create<arith::ExtFOp>(loc, target, input); 103 extFOp.setFastmath(arith::FastMathFlags::contract); 104 return extFOp; 105 }); 106 } 107 108 void mlir::arith::populateEmulateUnsupportedFloatsPatterns( 109 RewritePatternSet &patterns, const TypeConverter &converter) { 110 patterns.add<EmulateFloatPattern>(converter, patterns.getContext()); 111 } 112 113 void mlir::arith::populateEmulateUnsupportedFloatsLegality( 114 ConversionTarget &target, const TypeConverter &converter) { 115 // Don't try to legalize functions and other ops that don't need expansion. 116 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); 117 target.addDynamicallyLegalDialect<arith::ArithDialect>( 118 [&](Operation *op) -> std::optional<bool> { 119 return converter.isLegal(op); 120 }); 121 // Manually mark arithmetic-performing vector instructions. 122 target.addDynamicallyLegalOp< 123 vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp, 124 vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( 125 [&](Operation *op) { return converter.isLegal(op); }); 126 target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, 127 arith::ConstantOp, vector::SplatOp>(); 128 } 129 130 void EmulateUnsupportedFloatsPass::runOnOperation() { 131 MLIRContext *ctx = &getContext(); 132 Operation *op = getOperation(); 133 SmallVector<Type> sourceTypes; 134 Type targetType; 135 136 std::optional<FloatType> maybeTargetType = 137 arith::parseFloatType(ctx, targetTypeStr); 138 if (!maybeTargetType) { 139 emitError(UnknownLoc::get(ctx), "could not map target type '" + 140 targetTypeStr + 141 "' to a known floating-point type"); 142 return signalPassFailure(); 143 } 144 targetType = *maybeTargetType; 145 for (StringRef sourceTypeStr : sourceTypeStrs) { 146 std::optional<FloatType> maybeSourceType = 147 arith::parseFloatType(ctx, sourceTypeStr); 148 if (!maybeSourceType) { 149 emitError(UnknownLoc::get(ctx), "could not map source type '" + 150 sourceTypeStr + 151 "' to a known floating-point type"); 152 return signalPassFailure(); 153 } 154 sourceTypes.push_back(*maybeSourceType); 155 } 156 if (sourceTypes.empty()) 157 (void)emitOptionalWarning( 158 std::nullopt, 159 "no source types specified, float emulation will do nothing"); 160 161 if (llvm::is_contained(sourceTypes, targetType)) { 162 emitError(UnknownLoc::get(ctx), 163 "target type cannot be an unsupported source type"); 164 return signalPassFailure(); 165 } 166 TypeConverter converter; 167 arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, 168 targetType); 169 RewritePatternSet patterns(ctx); 170 arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); 171 ConversionTarget target(getContext()); 172 arith::populateEmulateUnsupportedFloatsLegality(target, converter); 173 174 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 175 signalPassFailure(); 176 } 177