xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass promotes small floats (of some unsupported types T) to a supported
9 // type U by wrapping all float operations on Ts with expansion to and
10 // truncation from U, then operating on U.
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include <optional>
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 } // namespace mlir::arith
30 
31 using namespace mlir;
32 
33 namespace {
34 struct EmulateUnsupportedFloatsPass
35     : arith::impl::ArithEmulateUnsupportedFloatsBase<
36           EmulateUnsupportedFloatsPass> {
37   using arith::impl::ArithEmulateUnsupportedFloatsBase<
38       EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39 
40   void runOnOperation() override;
41 };
42 
43 struct EmulateFloatPattern final : ConversionPattern {
44   EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
45       : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
46 
47   LogicalResult match(Operation *op) const override;
48   void rewrite(Operation *op, ArrayRef<Value> operands,
49                ConversionPatternRewriter &rewriter) const override;
50 };
51 } // end namespace
52 
53 LogicalResult EmulateFloatPattern::match(Operation *op) const {
54   if (getTypeConverter()->isLegal(op))
55     return failure();
56   // The rewrite doesn't handle cloning regions.
57   if (op->getNumRegions() != 0)
58     return failure();
59   return success();
60 }
61 
62 void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
63                                   ConversionPatternRewriter &rewriter) const {
64   Location loc = op->getLoc();
65   const TypeConverter *converter = getTypeConverter();
66   SmallVector<Type> resultTypes;
67   if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
68     // Note to anyone looking for this error message: this is a "can't happen".
69     // If you're seeing it, there's a bug.
70     op->emitOpError("type conversion failed in float emulation");
71     return;
72   }
73   Operation *expandedOp =
74       rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
75                       op->getAttrs(), op->getSuccessors(), /*regions=*/{});
76   SmallVector<Value> newResults(expandedOp->getResults());
77   for (auto [res, oldType, newType] : llvm::zip_equal(
78            MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
79     if (oldType != newType) {
80       auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
81       truncFOp.setFastmath(arith::FastMathFlags::contract);
82       res = truncFOp.getResult();
83     }
84   }
85   rewriter.replaceOp(op, newResults);
86 }
87 
88 void mlir::arith::populateEmulateUnsupportedFloatsConversions(
89     TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
90   converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
91                            targetType](Type type) -> std::optional<Type> {
92     if (llvm::is_contained(sourceTypes, type))
93       return targetType;
94     if (auto shaped = dyn_cast<ShapedType>(type))
95       if (llvm::is_contained(sourceTypes, shaped.getElementType()))
96         return shaped.clone(targetType);
97     // All other types legal
98     return type;
99   });
100   converter.addTargetMaterialization(
101       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
102         auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
103         extFOp.setFastmath(arith::FastMathFlags::contract);
104         return extFOp;
105       });
106 }
107 
108 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
109     RewritePatternSet &patterns, const TypeConverter &converter) {
110   patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
111 }
112 
113 void mlir::arith::populateEmulateUnsupportedFloatsLegality(
114     ConversionTarget &target, const TypeConverter &converter) {
115   // Don't try to legalize functions and other ops that don't need expansion.
116   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
117   target.addDynamicallyLegalDialect<arith::ArithDialect>(
118       [&](Operation *op) -> std::optional<bool> {
119         return converter.isLegal(op);
120       });
121   // Manually mark arithmetic-performing vector instructions.
122   target.addDynamicallyLegalOp<
123       vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
124       vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
125       [&](Operation *op) { return converter.isLegal(op); });
126   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
127                     arith::ConstantOp, vector::SplatOp>();
128 }
129 
130 void EmulateUnsupportedFloatsPass::runOnOperation() {
131   MLIRContext *ctx = &getContext();
132   Operation *op = getOperation();
133   SmallVector<Type> sourceTypes;
134   Type targetType;
135 
136   std::optional<FloatType> maybeTargetType =
137       arith::parseFloatType(ctx, targetTypeStr);
138   if (!maybeTargetType) {
139     emitError(UnknownLoc::get(ctx), "could not map target type '" +
140                                         targetTypeStr +
141                                         "' to a known floating-point type");
142     return signalPassFailure();
143   }
144   targetType = *maybeTargetType;
145   for (StringRef sourceTypeStr : sourceTypeStrs) {
146     std::optional<FloatType> maybeSourceType =
147         arith::parseFloatType(ctx, sourceTypeStr);
148     if (!maybeSourceType) {
149       emitError(UnknownLoc::get(ctx), "could not map source type '" +
150                                           sourceTypeStr +
151                                           "' to a known floating-point type");
152       return signalPassFailure();
153     }
154     sourceTypes.push_back(*maybeSourceType);
155   }
156   if (sourceTypes.empty())
157     (void)emitOptionalWarning(
158         std::nullopt,
159         "no source types specified, float emulation will do nothing");
160 
161   if (llvm::is_contained(sourceTypes, targetType)) {
162     emitError(UnknownLoc::get(ctx),
163               "target type cannot be an unsupported source type");
164     return signalPassFailure();
165   }
166   TypeConverter converter;
167   arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
168                                                      targetType);
169   RewritePatternSet patterns(ctx);
170   arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
171   ConversionTarget target(getContext());
172   arith::populateEmulateUnsupportedFloatsLegality(target, converter);
173 
174   if (failed(applyPartialConversion(op, target, std::move(patterns))))
175     signalPassFailure();
176 }
177