//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper functions to call common simple C functions in
// LLVMIR (among other uses, to support printing and debugging).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::LLVM;

/// Helper functions to look up or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as
/// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64"; 28 static constexpr llvm::StringRef kPrintU64 = "printU64"; 29 static constexpr llvm::StringRef kPrintF16 = "printF16"; 30 static constexpr llvm::StringRef kPrintBF16 = "printBF16"; 31 static constexpr llvm::StringRef kPrintF32 = "printF32"; 32 static constexpr llvm::StringRef kPrintF64 = "printF64"; 33 static constexpr llvm::StringRef kPrintString = "printString"; 34 static constexpr llvm::StringRef kPrintOpen = "printOpen"; 35 static constexpr llvm::StringRef kPrintClose = "printClose"; 36 static constexpr llvm::StringRef kPrintComma = "printComma"; 37 static constexpr llvm::StringRef kPrintNewline = "printNewline"; 38 static constexpr llvm::StringRef kMalloc = "malloc"; 39 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc"; 40 static constexpr llvm::StringRef kFree = "free"; 41 static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc"; 42 static constexpr llvm::StringRef kGenericAlignedAlloc = 43 "_mlir_memref_to_llvm_aligned_alloc"; 44 static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free"; 45 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; 46 47 /// Generic print function lookupOrCreate helper. 
48 FailureOr<LLVM::LLVMFuncOp> 49 mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, 50 ArrayRef<Type> paramTypes, Type resultType, 51 bool isVarArg, bool isReserved) { 52 assert(moduleOp->hasTrait<OpTrait::SymbolTable>() && 53 "expected SymbolTable operation"); 54 auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>( 55 SymbolTable::lookupSymbolIn(moduleOp, name)); 56 auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg); 57 // Assert the signature of the found function is same as expected 58 if (func) { 59 if (funcT != func.getFunctionType()) { 60 if (isReserved) { 61 func.emitError("redefinition of reserved function '") 62 << name << "' of different type " << func.getFunctionType() 63 << " is prohibited"; 64 } else { 65 func.emitError("redefinition of function '") 66 << name << "' of different type " << funcT << " is prohibited"; 67 } 68 return failure(); 69 } 70 return func; 71 } 72 OpBuilder b(moduleOp->getRegion(0)); 73 return b.create<LLVM::LLVMFuncOp>( 74 moduleOp->getLoc(), name, 75 LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); 76 } 77 78 static FailureOr<LLVM::LLVMFuncOp> 79 lookupOrCreateReservedFn(Operation *moduleOp, StringRef name, 80 ArrayRef<Type> paramTypes, Type resultType) { 81 return lookupOrCreateFn(moduleOp, name, paramTypes, resultType, 82 /*isVarArg=*/false, /*isReserved=*/true); 83 } 84 85 FailureOr<LLVM::LLVMFuncOp> 86 mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { 87 return lookupOrCreateReservedFn( 88 moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), 89 LLVM::LLVMVoidType::get(moduleOp->getContext())); 90 } 91 92 FailureOr<LLVM::LLVMFuncOp> 93 mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { 94 return lookupOrCreateReservedFn( 95 moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), 96 LLVM::LLVMVoidType::get(moduleOp->getContext())); 97 } 98 99 FailureOr<LLVM::LLVMFuncOp> 100 mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { 
101 return lookupOrCreateReservedFn( 102 moduleOp, kPrintF16, 103 IntegerType::get(moduleOp->getContext(), 16), // bits! 104 LLVM::LLVMVoidType::get(moduleOp->getContext())); 105 } 106 107 FailureOr<LLVM::LLVMFuncOp> 108 mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { 109 return lookupOrCreateReservedFn( 110 moduleOp, kPrintBF16, 111 IntegerType::get(moduleOp->getContext(), 16), // bits! 112 LLVM::LLVMVoidType::get(moduleOp->getContext())); 113 } 114 115 FailureOr<LLVM::LLVMFuncOp> 116 mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { 117 return lookupOrCreateReservedFn( 118 moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), 119 LLVM::LLVMVoidType::get(moduleOp->getContext())); 120 } 121 122 FailureOr<LLVM::LLVMFuncOp> 123 mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { 124 return lookupOrCreateReservedFn( 125 moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), 126 LLVM::LLVMVoidType::get(moduleOp->getContext())); 127 } 128 129 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { 130 return LLVM::LLVMPointerType::get(context); 131 } 132 133 static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { 134 // A char pointer and void ptr are the same in LLVM IR. 
135 return getCharPtr(context); 136 } 137 138 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn( 139 Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) { 140 return lookupOrCreateReservedFn( 141 moduleOp, runtimeFunctionName.value_or(kPrintString), 142 getCharPtr(moduleOp->getContext()), 143 LLVM::LLVMVoidType::get(moduleOp->getContext())); 144 } 145 146 FailureOr<LLVM::LLVMFuncOp> 147 mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { 148 return lookupOrCreateReservedFn( 149 moduleOp, kPrintOpen, {}, 150 LLVM::LLVMVoidType::get(moduleOp->getContext())); 151 } 152 153 FailureOr<LLVM::LLVMFuncOp> 154 mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { 155 return lookupOrCreateReservedFn( 156 moduleOp, kPrintClose, {}, 157 LLVM::LLVMVoidType::get(moduleOp->getContext())); 158 } 159 160 FailureOr<LLVM::LLVMFuncOp> 161 mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { 162 return lookupOrCreateReservedFn( 163 moduleOp, kPrintComma, {}, 164 LLVM::LLVMVoidType::get(moduleOp->getContext())); 165 } 166 167 FailureOr<LLVM::LLVMFuncOp> 168 mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { 169 return lookupOrCreateReservedFn( 170 moduleOp, kPrintNewline, {}, 171 LLVM::LLVMVoidType::get(moduleOp->getContext())); 172 } 173 174 FailureOr<LLVM::LLVMFuncOp> 175 mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { 176 return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType, 177 getVoidPtr(moduleOp->getContext())); 178 } 179 180 FailureOr<LLVM::LLVMFuncOp> 181 mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { 182 return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc, 183 {indexType, indexType}, 184 getVoidPtr(moduleOp->getContext())); 185 } 186 187 FailureOr<LLVM::LLVMFuncOp> 188 mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { 189 return lookupOrCreateReservedFn( 190 moduleOp, kFree, getVoidPtr(moduleOp->getContext()), 191 
LLVM::LLVMVoidType::get(moduleOp->getContext())); 192 } 193 194 FailureOr<LLVM::LLVMFuncOp> 195 mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { 196 return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType, 197 getVoidPtr(moduleOp->getContext())); 198 } 199 200 FailureOr<LLVM::LLVMFuncOp> 201 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, 202 Type indexType) { 203 return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc, 204 {indexType, indexType}, 205 getVoidPtr(moduleOp->getContext())); 206 } 207 208 FailureOr<LLVM::LLVMFuncOp> 209 mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { 210 return lookupOrCreateReservedFn( 211 moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), 212 LLVM::LLVMVoidType::get(moduleOp->getContext())); 213 } 214 215 FailureOr<LLVM::LLVMFuncOp> 216 mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, 217 Type unrankedDescriptorType) { 218 return lookupOrCreateReservedFn( 219 moduleOp, kMemRefCopy, 220 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType}, 221 LLVM::LLVMVoidType::get(moduleOp->getContext())); 222 } 223