1 //===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" 10 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 11 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "llvm/ADT/ArrayRef.h" 16 17 using namespace mlir; 18 using namespace llvm; 19 20 static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, 21 StringRef symbolName) { 22 static int counter = 0; 23 std::string uniqueName = std::string(symbolName); 24 while (moduleOp.lookupSymbol(uniqueName)) { 25 uniqueName = std::string(symbolName) + "_" + std::to_string(counter++); 26 } 27 return uniqueName; 28 } 29 30 LogicalResult mlir::LLVM::createPrintStrCall( 31 OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, 32 StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline, 33 std::optional<StringRef> runtimeFunctionName) { 34 auto ip = builder.saveInsertionPoint(); 35 builder.setInsertionPointToStart(moduleOp.getBody()); 36 MLIRContext *ctx = builder.getContext(); 37 38 // Create a zero-terminated byte representation and allocate global symbol. 39 SmallVector<uint8_t> elementVals; 40 elementVals.append(string.begin(), string.end()); 41 if (addNewline) 42 elementVals.push_back('\n'); 43 elementVals.push_back('\0'); 44 auto dataAttrType = RankedTensorType::get( 45 {static_cast<int64_t>(elementVals.size())}, builder.getI8Type()); 46 auto dataAttr = 47 DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); 48 auto arrayTy = 49 LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); 50 auto globalOp = builder.create<LLVM::GlobalOp>( 51 loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, 52 ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr); 53 54 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); 55 // Emit call to `printStr` in runtime library. 56 builder.restoreInsertionPoint(ip); 57 auto msgAddr = 58 builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName()); 59 SmallVector<LLVM::GEPArg> indices(1, 0); 60 Value gep = 61 builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices); 62 FailureOr<LLVM::LLVMFuncOp> printer = 63 LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); 64 if (failed(printer)) 65 return failure(); 66 builder.create<LLVM::CallOp>(loc, TypeRange(), 67 SymbolRefAttr::get(printer.value()), gep); 68 return success(); 69 } 70