xref: /llvm-project/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
10 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
11 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "llvm/ADT/ArrayRef.h"
16 
17 using namespace mlir;
18 using namespace llvm;
19 
20 static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
21                                             StringRef symbolName) {
22   static int counter = 0;
23   std::string uniqueName = std::string(symbolName);
24   while (moduleOp.lookupSymbol(uniqueName)) {
25     uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
26   }
27   return uniqueName;
28 }
29 
30 LogicalResult mlir::LLVM::createPrintStrCall(
31     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
32     StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
33     std::optional<StringRef> runtimeFunctionName) {
34   auto ip = builder.saveInsertionPoint();
35   builder.setInsertionPointToStart(moduleOp.getBody());
36   MLIRContext *ctx = builder.getContext();
37 
38   // Create a zero-terminated byte representation and allocate global symbol.
39   SmallVector<uint8_t> elementVals;
40   elementVals.append(string.begin(), string.end());
41   if (addNewline)
42     elementVals.push_back('\n');
43   elementVals.push_back('\0');
44   auto dataAttrType = RankedTensorType::get(
45       {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
46   auto dataAttr =
47       DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
48   auto arrayTy =
49       LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
50   auto globalOp = builder.create<LLVM::GlobalOp>(
51       loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
52       ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
53 
54   auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
55   // Emit call to `printStr` in runtime library.
56   builder.restoreInsertionPoint(ip);
57   auto msgAddr =
58       builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName());
59   SmallVector<LLVM::GEPArg> indices(1, 0);
60   Value gep =
61       builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
62   FailureOr<LLVM::LLVMFuncOp> printer =
63       LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
64   if (failed(printer))
65     return failure();
66   builder.create<LLVM::CallOp>(loc, TypeRange(),
67                                SymbolRefAttr::get(printer.value()), gep);
68   return success();
69 }
70