xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp (revision f5aee1f18bdbc5694330a5e86eb46cf60e653d0c)
//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "mlir/Dialect/Arith/Transforms/Passes.h"
11 
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/ADT/APInt.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/MathExtras.h"
22 #include <cassert>
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Public Interface Definition
28 //===----------------------------------------------------------------------===//
29 
30 arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
31     unsigned targetBitwidth)
32     : loadStoreBitwidth(targetBitwidth) {
33   assert(llvm::isPowerOf2_32(targetBitwidth) &&
34          "Only power-of-two integers are supported");
35 
36   // Allow unknown types.
37   addConversion([](Type ty) -> std::optional<Type> { return ty; });
38 
39   // Function case.
40   addConversion([this](FunctionType ty) -> std::optional<Type> {
41     SmallVector<Type> inputs;
42     if (failed(convertTypes(ty.getInputs(), inputs)))
43       return nullptr;
44 
45     SmallVector<Type> results;
46     if (failed(convertTypes(ty.getResults(), results)))
47       return nullptr;
48 
49     return FunctionType::get(ty.getContext(), inputs, results);
50   });
51 }
52 
53 void arith::populateArithNarrowTypeEmulationPatterns(
54     const NarrowTypeEmulationConverter &typeConverter,
55     RewritePatternSet &patterns) {
56   // Populate `func.*` conversion patterns.
57   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
58                                                                  typeConverter);
59   populateCallOpTypeConversionPattern(patterns, typeConverter);
60   populateReturnOpTypeConversionPattern(patterns, typeConverter);
61 }
62