xref: /llvm-project/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
13be3883eSBenjamin Maxwell //===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
23be3883eSBenjamin Maxwell //
33be3883eSBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43be3883eSBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information.
53be3883eSBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63be3883eSBenjamin Maxwell //
73be3883eSBenjamin Maxwell //===----------------------------------------------------------------------===//
83be3883eSBenjamin Maxwell 
93be3883eSBenjamin Maxwell #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
103be3883eSBenjamin Maxwell #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
113be3883eSBenjamin Maxwell #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
123be3883eSBenjamin Maxwell #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
133be3883eSBenjamin Maxwell #include "mlir/IR/Builders.h"
143be3883eSBenjamin Maxwell #include "mlir/IR/BuiltinOps.h"
153be3883eSBenjamin Maxwell #include "llvm/ADT/ArrayRef.h"
163be3883eSBenjamin Maxwell 
173be3883eSBenjamin Maxwell using namespace mlir;
183be3883eSBenjamin Maxwell using namespace llvm;
193be3883eSBenjamin Maxwell 
203be3883eSBenjamin Maxwell static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
213be3883eSBenjamin Maxwell                                             StringRef symbolName) {
223be3883eSBenjamin Maxwell   static int counter = 0;
233be3883eSBenjamin Maxwell   std::string uniqueName = std::string(symbolName);
243be3883eSBenjamin Maxwell   while (moduleOp.lookupSymbol(uniqueName)) {
253be3883eSBenjamin Maxwell     uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
263be3883eSBenjamin Maxwell   }
273be3883eSBenjamin Maxwell   return uniqueName;
283be3883eSBenjamin Maxwell }
293be3883eSBenjamin Maxwell 
30*e84f6b6aSLuohao Wang LogicalResult mlir::LLVM::createPrintStrCall(
313be3883eSBenjamin Maxwell     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
323be3883eSBenjamin Maxwell     StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
333be3883eSBenjamin Maxwell     std::optional<StringRef> runtimeFunctionName) {
343be3883eSBenjamin Maxwell   auto ip = builder.saveInsertionPoint();
353be3883eSBenjamin Maxwell   builder.setInsertionPointToStart(moduleOp.getBody());
363be3883eSBenjamin Maxwell   MLIRContext *ctx = builder.getContext();
373be3883eSBenjamin Maxwell 
383be3883eSBenjamin Maxwell   // Create a zero-terminated byte representation and allocate global symbol.
393be3883eSBenjamin Maxwell   SmallVector<uint8_t> elementVals;
403be3883eSBenjamin Maxwell   elementVals.append(string.begin(), string.end());
413be3883eSBenjamin Maxwell   if (addNewline)
423be3883eSBenjamin Maxwell     elementVals.push_back('\n');
433be3883eSBenjamin Maxwell   elementVals.push_back('\0');
443be3883eSBenjamin Maxwell   auto dataAttrType = RankedTensorType::get(
453be3883eSBenjamin Maxwell       {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
463be3883eSBenjamin Maxwell   auto dataAttr =
473be3883eSBenjamin Maxwell       DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
483be3883eSBenjamin Maxwell   auto arrayTy =
493be3883eSBenjamin Maxwell       LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
503be3883eSBenjamin Maxwell   auto globalOp = builder.create<LLVM::GlobalOp>(
513be3883eSBenjamin Maxwell       loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
523be3883eSBenjamin Maxwell       ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
533be3883eSBenjamin Maxwell 
5497a238e8SChristian Ulmann   auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
553be3883eSBenjamin Maxwell   // Emit call to `printStr` in runtime library.
563be3883eSBenjamin Maxwell   builder.restoreInsertionPoint(ip);
5797a238e8SChristian Ulmann   auto msgAddr =
5897a238e8SChristian Ulmann       builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName());
593be3883eSBenjamin Maxwell   SmallVector<LLVM::GEPArg> indices(1, 0);
6097a238e8SChristian Ulmann   Value gep =
6197a238e8SChristian Ulmann       builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
62*e84f6b6aSLuohao Wang   FailureOr<LLVM::LLVMFuncOp> printer =
6397a238e8SChristian Ulmann       LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
64*e84f6b6aSLuohao Wang   if (failed(printer))
65*e84f6b6aSLuohao Wang     return failure();
66*e84f6b6aSLuohao Wang   builder.create<LLVM::CallOp>(loc, TypeRange(),
67*e84f6b6aSLuohao Wang                                SymbolRefAttr::get(printer.value()), gep);
68*e84f6b6aSLuohao Wang   return success();
693be3883eSBenjamin Maxwell }
70