xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper functions to call common simple C functions in
10 // LLVMIR (e.g., among others, to support printing and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as  part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF16 = "printF16";
30 static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31 static constexpr llvm::StringRef kPrintF32 = "printF32";
32 static constexpr llvm::StringRef kPrintF64 = "printF64";
33 static constexpr llvm::StringRef kPrintString = "printString";
34 static constexpr llvm::StringRef kPrintOpen = "printOpen";
35 static constexpr llvm::StringRef kPrintClose = "printClose";
36 static constexpr llvm::StringRef kPrintComma = "printComma";
37 static constexpr llvm::StringRef kPrintNewline = "printNewline";
38 static constexpr llvm::StringRef kMalloc = "malloc";
39 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
40 static constexpr llvm::StringRef kFree = "free";
41 static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
42 static constexpr llvm::StringRef kGenericAlignedAlloc =
43     "_mlir_memref_to_llvm_aligned_alloc";
44 static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
45 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
46 
47 /// Generic print function lookupOrCreate helper.
48 FailureOr<LLVM::LLVMFuncOp>
49 mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
50                              ArrayRef<Type> paramTypes, Type resultType,
51                              bool isVarArg, bool isReserved) {
52   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
53          "expected SymbolTable operation");
54   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
55       SymbolTable::lookupSymbolIn(moduleOp, name));
56   auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
57   // Assert the signature of the found function is same as expected
58   if (func) {
59     if (funcT != func.getFunctionType()) {
60       if (isReserved) {
61         func.emitError("redefinition of reserved function '")
62             << name << "' of different type " << func.getFunctionType()
63             << " is prohibited";
64       } else {
65         func.emitError("redefinition of function '")
66             << name << "' of different type " << funcT << " is prohibited";
67       }
68       return failure();
69     }
70     return func;
71   }
72   OpBuilder b(moduleOp->getRegion(0));
73   return b.create<LLVM::LLVMFuncOp>(
74       moduleOp->getLoc(), name,
75       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
76 }
77 
78 static FailureOr<LLVM::LLVMFuncOp>
79 lookupOrCreateReservedFn(Operation *moduleOp, StringRef name,
80                          ArrayRef<Type> paramTypes, Type resultType) {
81   return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
82                           /*isVarArg=*/false, /*isReserved=*/true);
83 }
84 
85 FailureOr<LLVM::LLVMFuncOp>
86 mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
87   return lookupOrCreateReservedFn(
88       moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
89       LLVM::LLVMVoidType::get(moduleOp->getContext()));
90 }
91 
92 FailureOr<LLVM::LLVMFuncOp>
93 mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
94   return lookupOrCreateReservedFn(
95       moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
96       LLVM::LLVMVoidType::get(moduleOp->getContext()));
97 }
98 
99 FailureOr<LLVM::LLVMFuncOp>
100 mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
101   return lookupOrCreateReservedFn(
102       moduleOp, kPrintF16,
103       IntegerType::get(moduleOp->getContext(), 16), // bits!
104       LLVM::LLVMVoidType::get(moduleOp->getContext()));
105 }
106 
107 FailureOr<LLVM::LLVMFuncOp>
108 mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
109   return lookupOrCreateReservedFn(
110       moduleOp, kPrintBF16,
111       IntegerType::get(moduleOp->getContext(), 16), // bits!
112       LLVM::LLVMVoidType::get(moduleOp->getContext()));
113 }
114 
115 FailureOr<LLVM::LLVMFuncOp>
116 mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
117   return lookupOrCreateReservedFn(
118       moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
119       LLVM::LLVMVoidType::get(moduleOp->getContext()));
120 }
121 
122 FailureOr<LLVM::LLVMFuncOp>
123 mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
124   return lookupOrCreateReservedFn(
125       moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
126       LLVM::LLVMVoidType::get(moduleOp->getContext()));
127 }
128 
129 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
130   return LLVM::LLVMPointerType::get(context);
131 }
132 
133 static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
134   // A char pointer and void ptr are the same in LLVM IR.
135   return getCharPtr(context);
136 }
137 
138 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
139     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
140   return lookupOrCreateReservedFn(
141       moduleOp, runtimeFunctionName.value_or(kPrintString),
142       getCharPtr(moduleOp->getContext()),
143       LLVM::LLVMVoidType::get(moduleOp->getContext()));
144 }
145 
146 FailureOr<LLVM::LLVMFuncOp>
147 mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
148   return lookupOrCreateReservedFn(
149       moduleOp, kPrintOpen, {},
150       LLVM::LLVMVoidType::get(moduleOp->getContext()));
151 }
152 
153 FailureOr<LLVM::LLVMFuncOp>
154 mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
155   return lookupOrCreateReservedFn(
156       moduleOp, kPrintClose, {},
157       LLVM::LLVMVoidType::get(moduleOp->getContext()));
158 }
159 
160 FailureOr<LLVM::LLVMFuncOp>
161 mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
162   return lookupOrCreateReservedFn(
163       moduleOp, kPrintComma, {},
164       LLVM::LLVMVoidType::get(moduleOp->getContext()));
165 }
166 
167 FailureOr<LLVM::LLVMFuncOp>
168 mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
169   return lookupOrCreateReservedFn(
170       moduleOp, kPrintNewline, {},
171       LLVM::LLVMVoidType::get(moduleOp->getContext()));
172 }
173 
174 FailureOr<LLVM::LLVMFuncOp>
175 mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
176   return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
177                                   getVoidPtr(moduleOp->getContext()));
178 }
179 
180 FailureOr<LLVM::LLVMFuncOp>
181 mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
182   return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
183                                   {indexType, indexType},
184                                   getVoidPtr(moduleOp->getContext()));
185 }
186 
187 FailureOr<LLVM::LLVMFuncOp>
188 mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
189   return lookupOrCreateReservedFn(
190       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
191       LLVM::LLVMVoidType::get(moduleOp->getContext()));
192 }
193 
194 FailureOr<LLVM::LLVMFuncOp>
195 mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
196   return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
197                                   getVoidPtr(moduleOp->getContext()));
198 }
199 
200 FailureOr<LLVM::LLVMFuncOp>
201 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
202                                                 Type indexType) {
203   return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
204                                   {indexType, indexType},
205                                   getVoidPtr(moduleOp->getContext()));
206 }
207 
208 FailureOr<LLVM::LLVMFuncOp>
209 mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
210   return lookupOrCreateReservedFn(
211       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
212       LLVM::LLVMVoidType::get(moduleOp->getContext()));
213 }
214 
215 FailureOr<LLVM::LLVMFuncOp>
216 mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
217                                        Type unrankedDescriptorType) {
218   return lookupOrCreateReservedFn(
219       moduleOp, kMemRefCopy,
220       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
221       LLVM::LLVMVoidType::get(moduleOp->getContext()));
222 }
223