1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/Pattern.h" 10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 13 #include "mlir/IR/AffineMap.h" 14 #include "mlir/IR/BuiltinAttributes.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // ConvertToLLVMPattern 20 //===----------------------------------------------------------------------===// 21 22 ConvertToLLVMPattern::ConvertToLLVMPattern( 23 StringRef rootOpName, MLIRContext *context, 24 const LLVMTypeConverter &typeConverter, PatternBenefit benefit) 25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {} 26 27 const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { 28 return static_cast<const LLVMTypeConverter *>( 29 ConversionPattern::getTypeConverter()); 30 } 31 32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { 33 return *getTypeConverter()->getDialect(); 34 } 35 36 Type ConvertToLLVMPattern::getIndexType() const { 37 return getTypeConverter()->getIndexType(); 38 } 39 40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { 41 return IntegerType::get(&getTypeConverter()->getContext(), 42 getTypeConverter()->getPointerBitwidth(addressSpace)); 43 } 44 45 Type ConvertToLLVMPattern::getVoidType() const { 46 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); 47 } 48 49 Type ConvertToLLVMPattern::getVoidPtrType() const { 50 return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext()); 51 } 52 53 
/// Creates an `LLVM::ConstantOp` of type `resultType` holding `value` as an
/// index attribute.
Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
                                                    Location loc,
                                                    Type resultType,
                                                    int64_t value) {
  IntegerAttr valueAttr = builder.getIndexAttr(value);
  return builder.create<LLVM::ConstantOp>(loc, resultType, valueAttr);
}

/// Computes the address of the element of `memRefDesc` selected by `indices`.
/// The linear index is the sum of `indices[i] * stride[i]`; it is applied to
/// the descriptor's buffer pointer with a single GEP. When `indices` is empty
/// the buffer pointer itself is returned.
Value ConvertToLLVMPattern::getStridedElementPtr(
    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
    ConversionPatternRewriter &rewriter) const {

  auto [strides, offset] = type.getStridesAndOffset();

  MemRefDescriptor memRefDescriptor(memRefDesc);
  // Use a canonical representation of the start address so that later
  // optimizations have a longer sequence of instructions to CSE.
  // If we don't do that we would sprinkle the memref.offset in various
  // position of the different address computations.
  Value base =
      memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);

  Type indexType = getIndexType();
  Value linearIndex;
  for (unsigned dim = 0, numIndices = indices.size(); dim != numIndices;
       ++dim) {
    Value term = indices[dim];
    // A unit stride needs no multiplication.
    if (strides[dim] != 1) {
      Value strideValue;
      if (ShapedType::isDynamic(strides[dim]))
        // Dynamic strides are read from the descriptor.
        strideValue = memRefDescriptor.stride(rewriter, loc, dim);
      else
        strideValue =
            createIndexAttrConstant(rewriter, loc, indexType, strides[dim]);
      term = rewriter.create<LLVM::MulOp>(loc, term, strideValue);
    }

    // The first term seeds the sum; later terms are accumulated onto it.
    if (linearIndex)
      linearIndex = rewriter.create<LLVM::AddOp>(loc, linearIndex, term);
    else
      linearIndex = term;
  }

  Type elementPtrType = memRefDescriptor.getElementPtrType();
  if (!linearIndex)
    return base;

  Type convertedElementType =
      getTypeConverter()->convertType(type.getElementType());
  return rewriter.create<LLVM::GEPOp>(loc, elementPtrType,
                                      convertedElementType, base, linearIndex);
}

// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
100 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( 101 MemRefType type) const { 102 if (!typeConverter->convertType(type.getElementType())) 103 return false; 104 return type.getLayout().isIdentity(); 105 } 106 107 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { 108 auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type); 109 if (failed(addressSpace)) 110 return {}; 111 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); 112 } 113 114 void ConvertToLLVMPattern::getMemRefDescriptorSizes( 115 Location loc, MemRefType memRefType, ValueRange dynamicSizes, 116 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, 117 SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const { 118 assert(isConvertibleAndHasIdentityMaps(memRefType) && 119 "layout maps must have been normalized away"); 120 assert(count(memRefType.getShape(), ShapedType::kDynamic) == 121 static_cast<ssize_t>(dynamicSizes.size()) && 122 "dynamicSizes size doesn't match dynamic sizes count in memref shape"); 123 124 sizes.reserve(memRefType.getRank()); 125 unsigned dynamicIndex = 0; 126 Type indexType = getIndexType(); 127 for (int64_t size : memRefType.getShape()) { 128 sizes.push_back( 129 size == ShapedType::kDynamic 130 ? dynamicSizes[dynamicIndex++] 131 : createIndexAttrConstant(rewriter, loc, indexType, size)); 132 } 133 134 // Strides: iterate sizes in reverse order and multiply. 
135 int64_t stride = 1; 136 Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1); 137 strides.resize(memRefType.getRank()); 138 for (auto i = memRefType.getRank(); i-- > 0;) { 139 strides[i] = runningStride; 140 141 int64_t staticSize = memRefType.getShape()[i]; 142 bool useSizeAsStride = stride == 1; 143 if (staticSize == ShapedType::kDynamic) 144 stride = ShapedType::kDynamic; 145 if (stride != ShapedType::kDynamic) 146 stride *= staticSize; 147 148 if (useSizeAsStride) 149 runningStride = sizes[i]; 150 else if (stride == ShapedType::kDynamic) 151 runningStride = 152 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); 153 else 154 runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride); 155 } 156 if (sizeInBytes) { 157 // Buffer size in bytes. 158 Type elementType = typeConverter->convertType(memRefType.getElementType()); 159 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); 160 Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); 161 Value gepPtr = rewriter.create<LLVM::GEPOp>( 162 loc, elementPtrType, elementType, nullPtr, runningStride); 163 size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 164 } else { 165 size = runningStride; 166 } 167 } 168 169 Value ConvertToLLVMPattern::getSizeInBytes( 170 Location loc, Type type, ConversionPatternRewriter &rewriter) const { 171 // Compute the size of an individual element. This emits the MLIR equivalent 172 // of the following sizeof(...) implementation in LLVM IR: 173 // %0 = getelementptr %elementType* null, %indexType 1 174 // %1 = ptrtoint %elementType* %0 to %indexType 175 // which is a common pattern of getting the size of a type in bytes. 
176 Type llvmType = typeConverter->convertType(type); 177 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); 178 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType); 179 auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType, 180 nullPtr, ArrayRef<LLVM::GEPArg>{1}); 181 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); 182 } 183 184 Value ConvertToLLVMPattern::getNumElements( 185 Location loc, MemRefType memRefType, ValueRange dynamicSizes, 186 ConversionPatternRewriter &rewriter) const { 187 assert(count(memRefType.getShape(), ShapedType::kDynamic) == 188 static_cast<ssize_t>(dynamicSizes.size()) && 189 "dynamicSizes size doesn't match dynamic sizes count in memref shape"); 190 191 Type indexType = getIndexType(); 192 Value numElements = memRefType.getRank() == 0 193 ? createIndexAttrConstant(rewriter, loc, indexType, 1) 194 : nullptr; 195 unsigned dynamicIndex = 0; 196 197 // Compute the total number of memref elements. 198 for (int64_t staticSize : memRefType.getShape()) { 199 if (numElements) { 200 Value size = 201 staticSize == ShapedType::kDynamic 202 ? dynamicSizes[dynamicIndex++] 203 : createIndexAttrConstant(rewriter, loc, indexType, staticSize); 204 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); 205 } else { 206 numElements = 207 staticSize == ShapedType::kDynamic 208 ? dynamicSizes[dynamicIndex++] 209 : createIndexAttrConstant(rewriter, loc, indexType, staticSize); 210 } 211 } 212 return numElements; 213 } 214 215 /// Creates and populates the memref descriptor struct given all its fields. 
216 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( 217 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, 218 ArrayRef<Value> sizes, ArrayRef<Value> strides, 219 ConversionPatternRewriter &rewriter) const { 220 auto structType = typeConverter->convertType(memRefType); 221 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); 222 223 // Field 1: Allocated pointer, used for malloc/free. 224 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); 225 226 // Field 2: Actual aligned pointer to payload. 227 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); 228 229 // Field 3: Offset in aligned pointer. 230 Type indexType = getIndexType(); 231 memRefDescriptor.setOffset( 232 rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0)); 233 234 // Fields 4: Sizes. 235 for (const auto &en : llvm::enumerate(sizes)) 236 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); 237 238 // Field 5: Strides. 239 for (const auto &en : llvm::enumerate(strides)) 240 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); 241 242 return memRefDescriptor; 243 } 244 245 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( 246 OpBuilder &builder, Location loc, TypeRange origTypes, 247 SmallVectorImpl<Value> &operands, bool toDynamic) const { 248 assert(origTypes.size() == operands.size() && 249 "expected as may original types as operands"); 250 251 // Find operands of unranked memref type and store them. 
252 SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; 253 SmallVector<unsigned> unrankedAddressSpaces; 254 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 255 if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { 256 unrankedMemrefs.emplace_back(operands[i]); 257 FailureOr<unsigned> addressSpace = 258 getTypeConverter()->getMemRefAddressSpace(memRefType); 259 if (failed(addressSpace)) 260 return failure(); 261 unrankedAddressSpaces.emplace_back(*addressSpace); 262 } 263 } 264 265 if (unrankedMemrefs.empty()) 266 return success(); 267 268 // Compute allocation sizes. 269 SmallVector<Value> sizes; 270 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), 271 unrankedMemrefs, unrankedAddressSpaces, 272 sizes); 273 274 // Get frequently used types. 275 Type indexType = getTypeConverter()->getIndexType(); 276 277 // Find the malloc and free, or declare them if necessary. 278 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); 279 FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc; 280 if (toDynamic) { 281 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); 282 if (failed(mallocFunc)) 283 return failure(); 284 } 285 if (!toDynamic) { 286 freeFunc = LLVM::lookupOrCreateFreeFn(module); 287 if (failed(freeFunc)) 288 return failure(); 289 } 290 291 unsigned unrankedMemrefPos = 0; 292 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 293 Type type = origTypes[i]; 294 if (!isa<UnrankedMemRefType>(type)) 295 continue; 296 Value allocationSize = sizes[unrankedMemrefPos++]; 297 UnrankedMemRefDescriptor desc(operands[i]); 298 299 // Allocate memory, copy, and free the source if necessary. 300 Value memory = 301 toDynamic 302 ? 
builder 303 .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) 304 .getResult() 305 : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(), 306 IntegerType::get(getContext(), 8), 307 allocationSize, 308 /*alignment=*/0); 309 Value source = desc.memRefDescPtr(builder, loc); 310 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false); 311 if (!toDynamic) 312 builder.create<LLVM::CallOp>(loc, freeFunc.value(), source); 313 314 // Create a new descriptor. The same descriptor can be returned multiple 315 // times, attempting to modify its pointer can lead to memory leaks 316 // (allocated twice and overwritten) or double frees (the caller does not 317 // know if the descriptor points to the same memory). 318 Type descriptorType = getTypeConverter()->convertType(type); 319 if (!descriptorType) 320 return failure(); 321 auto updatedDesc = 322 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); 323 Value rank = desc.rank(builder, loc); 324 updatedDesc.setRank(builder, loc, rank); 325 updatedDesc.setMemRefDescPtr(builder, loc, memory); 326 327 operands[i] = updatedDesc; 328 } 329 330 return success(); 331 } 332 333 //===----------------------------------------------------------------------===// 334 // Detail methods 335 //===----------------------------------------------------------------------===// 336 337 void LLVM::detail::setNativeProperties(Operation *op, 338 IntegerOverflowFlags overflowFlags) { 339 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) 340 iface.setOverflowFlags(overflowFlags); 341 } 342 343 /// Replaces the given operation "op" with a new operation of type "targetOp" 344 /// and given operands. 
345 LogicalResult LLVM::detail::oneToOneRewrite( 346 Operation *op, StringRef targetOp, ValueRange operands, 347 ArrayRef<NamedAttribute> targetAttrs, 348 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, 349 IntegerOverflowFlags overflowFlags) { 350 unsigned numResults = op->getNumResults(); 351 352 SmallVector<Type> resultTypes; 353 if (numResults != 0) { 354 resultTypes.push_back( 355 typeConverter.packOperationResults(op->getResultTypes())); 356 if (!resultTypes.back()) 357 return failure(); 358 } 359 360 // Create the operation through state since we don't know its C++ type. 361 Operation *newOp = 362 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, 363 resultTypes, targetAttrs); 364 365 setNativeProperties(newOp, overflowFlags); 366 367 // If the operation produced 0 or 1 result, return them immediately. 368 if (numResults == 0) 369 return rewriter.eraseOp(op), success(); 370 if (numResults == 1) 371 return rewriter.replaceOp(op, newOp->getResult(0)), success(); 372 373 // Otherwise, it had been converted to an operation producing a structure. 374 // Extract individual results from the structure and return them as list. 375 SmallVector<Value, 4> results; 376 results.reserve(numResults); 377 for (unsigned i = 0; i < numResults; ++i) { 378 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 379 op->getLoc(), newOp->getResult(0), i)); 380 } 381 rewriter.replaceOp(op, results); 382 return success(); 383 } 384