//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
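  /// The layout is consulted by `alignedAllocationGetAlignment` to derive a
  /// power-of-two alignment from the element type when the op itself does not
  /// specify one.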
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away: the
    // allocation happens directly on the stack.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Restore the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
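// For example (an illustrative sketch): for `%d = memref.dim %m, %c1` with
// `%m : memref<4x?xf32>`, dimension 1 is dynamic, so the lowering reads the
// second entry of the descriptor's size array, whereas `memref.dim %m, %c0`
// folds to `llvm.mlir.constant(4 : index)`.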
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to the offset field of the memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being a constant, if possible.
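    // With a constant in-range index, a static dimension folds to an
    // llvm.mlir.constant and a dynamic one is read from the descriptor's size
    // array; otherwise the size is looked up at runtime in the descriptor
    // using the index value itself.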
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so we need to walk it
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a memref descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// The load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// The store operation is lowered to obtaining a pointer to the indexed
// element, and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
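    // Besides the address, the llvm.prefetch intrinsic takes three i32
    // arguments: whether the access is a read (0) or a write (1), a locality
    // hint in [0, 3], and whether to target the data (1) or the instruction
    // (0) cache.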
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is of unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked-to-unranked casts are disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting a ranked to an unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast: if the destination type mismatches
      // the runtime type held by the unranked descriptor, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements as the product of all the sizes.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting the descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
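    // The generic memrefCopy runtime function takes its descriptors by
    // pointer, so each unranked descriptor is stored into a fresh stack slot
    // and the slot's address is passed instead of the value.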
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the size from llvm.getelementptr, which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    if (failed(copyFn))
      return failure();
    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued entries (the offset, sizes, and strides),
      // which follow the two pointers in the underlying descriptor.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
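/// For reference, the standard descriptor layout assumed throughout this
/// file: a ranked memref lowers to the struct
/// `{allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank]}`.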
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
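    // Static sizes and strides are materialized as constants; dynamic ones
    // are consumed from the op's operand lists in declaration order.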
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // The stride for this dimension is static; use the constant value
          // directly.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop.
          // However, since the target memref has an identity layout, we can
          // safely set the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }
        // In the remaining (dynamic, non-innermost) case, `stride` already
        // holds the product of the sizes of the inner dimensions, computed at
        // the bottom of the previous iteration.

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract the address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract the pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set the pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
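    // The generated CFG is roughly (a sketch; the block names are
    // illustrative):
    //   init:              br cond(%rank - 1, /*stride=*/1)
    //   cond(%i, %stride): cond_br (%i >= 0), body, remainder
    //   body:              sizes[%i] = shape[%i]; strides[%i] = %stride;
    //                      br cond(%i - 1, %stride * shape[%i])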
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy the size from the shape to the descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write the stride value and compute the next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement the loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset the position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOps must be expanded before we reach this stage.
/// Report that information.
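/// (Expansion is typically performed beforehand, e.g., by the
/// expand-strided-metadata patterns; the pattern below merely reports a match
/// failure if such an op survives to this point.)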
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset pointer from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
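    // For example, with the permutation map (d0, d1) -> (d1, d0), target
    // dimension 0 receives the size and stride of source dimension 1, and
    // vice versa.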
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms an op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by looking up the matching
  // dynamic size in `dynamicSizes`.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx) to find the position
    // of this dimension in the dynamic-size operand list.
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current `runningStride`
  // and `nextSize`. The caller should keep a running stride and update it with
  // the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to the payload, shifted by the
    // byte offset.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for the 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides, walking from the innermost
    // dimension outwards so the running stride can be accumulated.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
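      // The dynamic case accumulates a running product of the sizes already
      // visited; e.g. a contiguous 16x?x8 view gets strides (? * 8, 8, 1),
      // with the middle size supplied at runtime.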
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    // Unhandled kinds are not simple atomics; callers fall back to the
    // generic cmpxchg-based lowering.
    return std::nullopt;
  }
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(memRefType.getStridesAndOffset(strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
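/// For example (illustrative IR):
///   %i = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
/// lowers to a read of the aligned pointer from the descriptor of %m followed
/// by an `llvm.ptrtoint` to the converted index type.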
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};

/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
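    // Strides are emitted in the same dimension order as the sizes above, so
    // the op's results read (base buffer, offset, sizes..., strides...).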
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

} // namespace

void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
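  ///
  /// A minimal usage sketch (assuming an mlir-opt-style driver): once this
  /// interface is registered, a generic `--convert-to-llvm` invocation pulls
  /// in these memref lowering patterns without running a dedicated pass.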
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};

} // namespace

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}