xref: /llvm-project/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <iterator>
#include <optional>
31 
32 namespace mlir {
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
36 
37 using namespace mlir;
38 
39 namespace {
40 
41 static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42   return !ShapedType::isDynamic(strideOrOffset);
43 }
44 
45 static FailureOr<LLVM::LLVMFuncOp>
46 getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
47   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
48 
49   if (useGenericFn)
50     return LLVM::lookupOrCreateGenericFreeFn(module);
51 
52   return LLVM::lookupOrCreateFreeFn(module);
53 }
54 
55 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
56   AllocOpLowering(const LLVMTypeConverter &converter)
57       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
58                                 converter) {}
59   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
60                                           Location loc, Value sizeBytes,
61                                           Operation *op) const override {
62     return allocateBufferManuallyAlign(
63         rewriter, loc, sizeBytes, op,
64         getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
65   }
66 };
67 
68 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
69   AlignedAllocOpLowering(const LLVMTypeConverter &converter)
70       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
71                                 converter) {}
72   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
73                                           Location loc, Value sizeBytes,
74                                           Operation *op) const override {
75     Value ptr = allocateBufferAutoAlign(
76         rewriter, loc, sizeBytes, op, &defaultLayout,
77         alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
78                                       &defaultLayout));
79     if (!ptr)
80       return std::make_tuple(Value(), Value());
81     return std::make_tuple(ptr, ptr);
82   }
83 
84 private:
85   /// Default layout to use in absence of the corresponding analysis.
86   DataLayout defaultLayout;
87 };
88 
89 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
90   AllocaOpLowering(const LLVMTypeConverter &converter)
91       : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
92                                 converter) {
93     setRequiresNumElements();
94   }
95 
96   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
97   /// is set to null for stack allocations. `accessAlignment` is set if
98   /// alignment is needed post allocation (for eg. in conjunction with malloc).
99   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
100                                           Location loc, Value size,
101                                           Operation *op) const override {
102 
103     // With alloca, one gets a pointer to the element type right away.
104     // For stack allocations.
105     auto allocaOp = cast<memref::AllocaOp>(op);
106     auto elementType =
107         typeConverter->convertType(allocaOp.getType().getElementType());
108     unsigned addrSpace =
109         *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110     auto elementPtrType =
111         LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
112 
113     auto allocatedElementPtr =
114         rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115                                         allocaOp.getAlignment().value_or(0));
116 
117     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
118   }
119 };
120 
121 struct AllocaScopeOpLowering
122     : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
123   using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
124 
125   LogicalResult
126   matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
127                   ConversionPatternRewriter &rewriter) const override {
128     OpBuilder::InsertionGuard guard(rewriter);
129     Location loc = allocaScopeOp.getLoc();
130 
131     // Split the current block before the AllocaScopeOp to create the inlining
132     // point.
133     auto *currentBlock = rewriter.getInsertionBlock();
134     auto *remainingOpsBlock =
135         rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
136     Block *continueBlock;
137     if (allocaScopeOp.getNumResults() == 0) {
138       continueBlock = remainingOpsBlock;
139     } else {
140       continueBlock = rewriter.createBlock(
141           remainingOpsBlock, allocaScopeOp.getResultTypes(),
142           SmallVector<Location>(allocaScopeOp->getNumResults(),
143                                 allocaScopeOp.getLoc()));
144       rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
145     }
146 
147     // Inline body region.
148     Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
149     Block *afterBody = &allocaScopeOp.getBodyRegion().back();
150     rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
151 
152     // Save stack and then branch into the body of the region.
153     rewriter.setInsertionPointToEnd(currentBlock);
154     auto stackSaveOp =
155         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
156     rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
157 
158     // Replace the alloca_scope return with a branch that jumps out of the body.
159     // Stack restore before leaving the body region.
160     rewriter.setInsertionPointToEnd(afterBody);
161     auto returnOp =
162         cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
163     auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
164         returnOp, returnOp.getResults(), continueBlock);
165 
166     // Insert stack restore before jumping out the body of the region.
167     rewriter.setInsertionPoint(branchOp);
168     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
169 
170     // Replace the op with values return from the body region.
171     rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
172 
173     return success();
174   }
175 };
176 
177 struct AssumeAlignmentOpLowering
178     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
179   using ConvertOpToLLVMPattern<
180       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
181   explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
182       : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
183 
184   LogicalResult
185   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
186                   ConversionPatternRewriter &rewriter) const override {
187     Value memref = adaptor.getMemref();
188     unsigned alignment = op.getAlignment();
189     auto loc = op.getLoc();
190 
191     auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192     Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
193                                      rewriter);
194 
195     // Emit llvm.assume(true) ["align"(memref, alignment)].
196     // This is more direct than ptrtoint-based checks, is explicitly supported,
197     // and works with non-integral address spaces.
198     Value trueCond =
199         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
200     Value alignmentConst =
201         createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
202     rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
203                                     alignmentConst);
204 
205     rewriter.eraseOp(op);
206     return success();
207   }
208 };
209 
210 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
211 // The memref descriptor being an SSA value, there is no need to clean it up
212 // in any way.
213 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
214   using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
215 
216   explicit DeallocOpLowering(const LLVMTypeConverter &converter)
217       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
218 
219   LogicalResult
220   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
221                   ConversionPatternRewriter &rewriter) const override {
222     // Insert the `free` declaration if it is not already present.
223     FailureOr<LLVM::LLVMFuncOp> freeFunc =
224         getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
225     if (failed(freeFunc))
226       return failure();
227     Value allocatedPtr;
228     if (auto unrankedTy =
229             llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
230       auto elementPtrTy = LLVM::LLVMPointerType::get(
231           rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
232       allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
233           rewriter, op.getLoc(),
234           UnrankedMemRefDescriptor(adaptor.getMemref())
235               .memRefDescPtr(rewriter, op.getLoc()),
236           elementPtrTy);
237     } else {
238       allocatedPtr = MemRefDescriptor(adaptor.getMemref())
239                          .allocatedPtr(rewriter, op.getLoc());
240     }
241     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
242                                               allocatedPtr);
243     return success();
244   }
245 };
246 
247 // A `dim` is converted to a constant for static sizes and to an access to the
248 // size stored in the memref descriptor for dynamic sizes.
249 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
250   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
251 
252   LogicalResult
253   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
254                   ConversionPatternRewriter &rewriter) const override {
255     Type operandType = dimOp.getSource().getType();
256     if (isa<UnrankedMemRefType>(operandType)) {
257       FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
258           operandType, dimOp, adaptor.getOperands(), rewriter);
259       if (failed(extractedSize))
260         return failure();
261       rewriter.replaceOp(dimOp, {*extractedSize});
262       return success();
263     }
264     if (isa<MemRefType>(operandType)) {
265       rewriter.replaceOp(
266           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
267                                             adaptor.getOperands(), rewriter)});
268       return success();
269     }
270     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
271   }
272 
273 private:
274   FailureOr<Value>
275   extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
276                               OpAdaptor adaptor,
277                               ConversionPatternRewriter &rewriter) const {
278     Location loc = dimOp.getLoc();
279 
280     auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
281     auto scalarMemRefType =
282         MemRefType::get({}, unrankedMemRefType.getElementType());
283     FailureOr<unsigned> maybeAddressSpace =
284         getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
285     if (failed(maybeAddressSpace)) {
286       dimOp.emitOpError("memref memory space must be convertible to an integer "
287                         "address space");
288       return failure();
289     }
290     unsigned addressSpace = *maybeAddressSpace;
291 
292     // Extract pointer to the underlying ranked descriptor and bitcast it to a
293     // memref<element_type> descriptor pointer to minimize the number of GEP
294     // operations.
295     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
296     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
297 
298     Type elementType = typeConverter->convertType(scalarMemRefType);
299 
300     // Get pointer to offset field of memref<element_type> descriptor.
301     auto indexPtrTy =
302         LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
303     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
304         loc, indexPtrTy, elementType, underlyingRankedDesc,
305         ArrayRef<LLVM::GEPArg>{0, 2});
306 
307     // The size value that we have to extract can be obtained using GEPop with
308     // `dimOp.index() + 1` index argument.
309     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
310         loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
311         adaptor.getIndex());
312     Value sizePtr = rewriter.create<LLVM::GEPOp>(
313         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
314         idxPlusOne);
315     return rewriter
316         .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
317         .getResult();
318   }
319 
320   std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
321     if (auto idx = dimOp.getConstantIndex())
322       return idx;
323 
324     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
325       return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
326 
327     return std::nullopt;
328   }
329 
330   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
331                                   OpAdaptor adaptor,
332                                   ConversionPatternRewriter &rewriter) const {
333     Location loc = dimOp.getLoc();
334 
335     // Take advantage if index is constant.
336     MemRefType memRefType = cast<MemRefType>(operandType);
337     Type indexType = getIndexType();
338     if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
339       int64_t i = *index;
340       if (i >= 0 && i < memRefType.getRank()) {
341         if (memRefType.isDynamicDim(i)) {
342           // extract dynamic size from the memref descriptor.
343           MemRefDescriptor descriptor(adaptor.getSource());
344           return descriptor.size(rewriter, loc, i);
345         }
346         // Use constant for static size.
347         int64_t dimSize = memRefType.getDimSize(i);
348         return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
349       }
350     }
351     Value index = adaptor.getIndex();
352     int64_t rank = memRefType.getRank();
353     MemRefDescriptor memrefDescriptor(adaptor.getSource());
354     return memrefDescriptor.size(rewriter, loc, index, rank);
355   }
356 };
357 
358 /// Common base for load and store operations on MemRefs. Restricts the match
359 /// to supported MemRef types. Provides functionality to emit code accessing a
360 /// specific element of the underlying data buffer.
361 template <typename Derived>
362 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
363   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
364   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
365   using Base = LoadStoreOpLowering<Derived>;
366 
367   LogicalResult match(Derived op) const override {
368     MemRefType type = op.getMemRefType();
369     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
370   }
371 };
372 
373 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
374 /// retried until it succeeds in atomically storing a new value into memory.
375 ///
376 ///      +---------------------------------+
377 ///      |   <code before the AtomicRMWOp> |
378 ///      |   <compute initial %loaded>     |
379 ///      |   cf.br loop(%loaded)              |
380 ///      +---------------------------------+
381 ///             |
382 ///  -------|   |
383 ///  |      v   v
384 ///  |   +--------------------------------+
385 ///  |   | loop(%loaded):                 |
386 ///  |   |   <body contents>              |
387 ///  |   |   %pair = cmpxchg              |
388 ///  |   |   %ok = %pair[0]               |
389 ///  |   |   %new = %pair[1]              |
390 ///  |   |   cf.cond_br %ok, end, loop(%new) |
391 ///  |   +--------------------------------+
392 ///  |          |        |
393 ///  |-----------        |
394 ///                      v
395 ///      +--------------------------------+
396 ///      | end:                           |
397 ///      |   <code after the AtomicRMWOp> |
398 ///      +--------------------------------+
399 ///
400 struct GenericAtomicRMWOpLowering
401     : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
402   using Base::Base;
403 
404   LogicalResult
405   matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
406                   ConversionPatternRewriter &rewriter) const override {
407     auto loc = atomicOp.getLoc();
408     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
409 
410     // Split the block into initial, loop, and ending parts.
411     auto *initBlock = rewriter.getInsertionBlock();
412     auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
413     loopBlock->addArgument(valueType, loc);
414 
415     auto *endBlock =
416         rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);
417 
418     // Compute the loaded value and branch to the loop block.
419     rewriter.setInsertionPointToEnd(initBlock);
420     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
421     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
422                                         adaptor.getIndices(), rewriter);
423     Value init = rewriter.create<LLVM::LoadOp>(
424         loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
425     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
426 
427     // Prepare the body of the loop block.
428     rewriter.setInsertionPointToStart(loopBlock);
429 
430     // Clone the GenericAtomicRMWOp region and extract the result.
431     auto loopArgument = loopBlock->getArgument(0);
432     IRMapping mapping;
433     mapping.map(atomicOp.getCurrentValue(), loopArgument);
434     Block &entryBlock = atomicOp.body().front();
435     for (auto &nestedOp : entryBlock.without_terminator()) {
436       Operation *clone = rewriter.clone(nestedOp, mapping);
437       mapping.map(nestedOp.getResults(), clone->getResults());
438     }
439     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
440 
441     // Prepare the epilog of the loop block.
442     // Append the cmpxchg op to the end of the loop block.
443     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
444     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
445     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
446         loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
447     // Extract the %new_loaded and %ok values from the pair.
448     Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
449     Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
450 
451     // Conditionally branch to the end or back to the loop depending on %ok.
452     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
453                                     loopBlock, newLoaded);
454 
455     rewriter.setInsertionPointToEnd(endBlock);
456 
457     // The 'result' of the atomic_rmw op is the newly loaded value.
458     rewriter.replaceOp(atomicOp, {newLoaded});
459 
460     return success();
461   }
462 };
463 
464 /// Returns the LLVM type of the global variable given the memref type `type`.
465 static Type
466 convertGlobalMemrefTypeToLLVM(MemRefType type,
467                               const LLVMTypeConverter &typeConverter) {
468   // LLVM type for a global memref will be a multi-dimension array. For
469   // declarations or uninitialized global memrefs, we can potentially flatten
470   // this to a 1D array. However, for memref.global's with an initial value,
471   // we do not intend to flatten the ElementsAttribute when going from std ->
472   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
473   Type elementType = typeConverter.convertType(type.getElementType());
474   Type arrayTy = elementType;
475   // Shape has the outermost dim at index 0, so need to walk it backwards
476   for (int64_t dim : llvm::reverse(type.getShape()))
477     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
478   return arrayTy;
479 }
480 
481 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
482 struct GlobalMemrefOpLowering
483     : public ConvertOpToLLVMPattern<memref::GlobalOp> {
484   using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
485 
486   LogicalResult
487   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
488                   ConversionPatternRewriter &rewriter) const override {
489     MemRefType type = global.getType();
490     if (!isConvertibleAndHasIdentityMaps(type))
491       return failure();
492 
493     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
494 
495     LLVM::Linkage linkage =
496         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
497 
498     Attribute initialValue = nullptr;
499     if (!global.isExternal() && !global.isUninitialized()) {
500       auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
501       initialValue = elementsAttr;
502 
503       // For scalar memrefs, the global variable created is of the element type,
504       // so unpack the elements attribute to extract the value.
505       if (type.getRank() == 0)
506         initialValue = elementsAttr.getSplatValue<Attribute>();
507     }
508 
509     uint64_t alignment = global.getAlignment().value_or(0);
510     FailureOr<unsigned> addressSpace =
511         getTypeConverter()->getMemRefAddressSpace(type);
512     if (failed(addressSpace))
513       return global.emitOpError(
514           "memory space cannot be converted to an integer address space");
515     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
516         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
517         initialValue, alignment, *addressSpace);
518     if (!global.isExternal() && global.isUninitialized()) {
519       rewriter.createBlock(&newGlobal.getInitializerRegion());
520       Value undef[] = {
521           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
522       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
523     }
524     return success();
525   }
526 };
527 
528 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
529 /// the first element stashed into the descriptor. This reuses
530 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
531 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
532   GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
533       : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
534                                 converter) {}
535 
536   /// Buffer "allocation" for memref.get_global op is getting the address of
537   /// the global variable referenced.
538   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
539                                           Location loc, Value sizeBytes,
540                                           Operation *op) const override {
541     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
542     MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
543 
544     // This is called after a type conversion, which would have failed if this
545     // call fails.
546     FailureOr<unsigned> maybeAddressSpace =
547         getTypeConverter()->getMemRefAddressSpace(type);
548     if (failed(maybeAddressSpace))
549       return std::make_tuple(Value(), Value());
550     unsigned memSpace = *maybeAddressSpace;
551 
552     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
553     auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
554     auto addressOf =
555         rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
556 
557     // Get the address of the first element in the array by creating a GEP with
558     // the address of the GV as the base, and (rank + 1) number of 0 indices.
559     auto gep = rewriter.create<LLVM::GEPOp>(
560         loc, ptrTy, arrayTy, addressOf,
561         SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
562 
563     // We do not expect the memref obtained using `memref.get_global` to be
564     // ever deallocated. Set the allocated pointer to be known bad value to
565     // help debug if that ever happens.
566     auto intPtrType = getIntPtrType(memSpace);
567     Value deadBeefConst =
568         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
569     auto deadBeefPtr =
570         rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
571 
572     // Both allocated and aligned pointers are same. We could potentially stash
573     // a nullptr for the allocated pointer since we do not expect any dealloc.
574     return std::make_tuple(deadBeefPtr, gep);
575   }
576 };
577 
578 // Load operation is lowered to obtaining a pointer to the indexed element
579 // and loading it.
580 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
581   using Base::Base;
582 
583   LogicalResult
584   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
585                   ConversionPatternRewriter &rewriter) const override {
586     auto type = loadOp.getMemRefType();
587 
588     Value dataPtr =
589         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
590                              adaptor.getIndices(), rewriter);
591     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
592         loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
593         false, loadOp.getNontemporal());
594     return success();
595   }
596 };
597 
598 // Store operation is lowered to obtaining a pointer to the indexed element,
599 // and storing the given value to it.
600 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
601   using Base::Base;
602 
603   LogicalResult
604   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
605                   ConversionPatternRewriter &rewriter) const override {
606     auto type = op.getMemRefType();
607 
608     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
609                                          adaptor.getIndices(), rewriter);
610     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
611                                                0, false, op.getNontemporal());
612     return success();
613   }
614 };
615 
616 // The prefetch operation is lowered in a way similar to the load operation
617 // except that the llvm.prefetch operation is used for replacement.
618 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
619   using Base::Base;
620 
621   LogicalResult
622   matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
623                   ConversionPatternRewriter &rewriter) const override {
624     auto type = prefetchOp.getMemRefType();
625     auto loc = prefetchOp.getLoc();
626 
627     Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
628                                          adaptor.getIndices(), rewriter);
629 
630     // Replace with llvm.prefetch.
631     IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
632     IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
633     IntegerAttr isData =
634         rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
635     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
636                                                 localityHint, isData);
637     return success();
638   }
639 };
640 
641 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
642   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
643 
644   LogicalResult
645   matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
646                   ConversionPatternRewriter &rewriter) const override {
647     Location loc = op.getLoc();
648     Type operandType = op.getMemref().getType();
649     if (dyn_cast<UnrankedMemRefType>(operandType)) {
650       UnrankedMemRefDescriptor desc(adaptor.getMemref());
651       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
652       return success();
653     }
654     if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
655       Type indexType = getIndexType();
656       rewriter.replaceOp(op,
657                          {createIndexAttrConstant(rewriter, loc, indexType,
658                                                   rankedMemRefType.getRank())});
659       return success();
660     }
661     return failure();
662   }
663 };
664 
/// Lowering for `memref.cast`:
///  * ranked -> ranked casts keep the source descriptor unchanged (both types
///    must convert to the same LLVM struct, checked in `match`),
///  * ranked -> unranked casts stack-promote the ranked descriptor and wrap
///    the (rank, pointer) pair into a new unranked descriptor,
///  * unranked -> ranked casts load the ranked descriptor back through the
///    pointer stored in the unranked descriptor.
/// Unranked -> unranked casts are rejected by `match`.
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now they must preserve underlying element type
    // and require source and result type to have the same rank. Therefore,
    // perform a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked to unranked cast is disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      //  - set the rank in the destination from the memref type,
      //  - allocate space on the stack and copy the src memref descriptor,
      //  - set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the runtime rank/type stored in the
      // unranked descriptor, behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
742 
743 /// Pattern to lower a `memref.copy` to llvm.
744 ///
745 /// For memrefs with identity layouts, the copy is lowered to the llvm
746 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
747 /// to the generic `MemrefCopyFn`.
748 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
749   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
750 
751   LogicalResult
752   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
753                           ConversionPatternRewriter &rewriter) const {
754     auto loc = op.getLoc();
755     auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
756 
757     MemRefDescriptor srcDesc(adaptor.getSource());
758 
759     // Compute number of elements.
760     Value numElements = rewriter.create<LLVM::ConstantOp>(
761         loc, getIndexType(), rewriter.getIndexAttr(1));
762     for (int pos = 0; pos < srcType.getRank(); ++pos) {
763       auto size = srcDesc.size(rewriter, loc, pos);
764       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
765     }
766 
767     // Get element size.
768     auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
769     // Compute total.
770     Value totalSize =
771         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
772 
773     Type elementType = typeConverter->convertType(srcType.getElementType());
774 
775     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
776     Value srcOffset = srcDesc.offset(rewriter, loc);
777     Value srcPtr = rewriter.create<LLVM::GEPOp>(
778         loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
779     MemRefDescriptor targetDesc(adaptor.getTarget());
780     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
781     Value targetOffset = targetDesc.offset(rewriter, loc);
782     Value targetPtr = rewriter.create<LLVM::GEPOp>(
783         loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
784     rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
785                                     /*isVolatile=*/false);
786     rewriter.eraseOp(op);
787 
788     return success();
789   }
790 
791   LogicalResult
792   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
793                              ConversionPatternRewriter &rewriter) const {
794     auto loc = op.getLoc();
795     auto srcType = cast<BaseMemRefType>(op.getSource().getType());
796     auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
797 
798     // First make sure we have an unranked memref descriptor representation.
799     auto makeUnranked = [&, this](Value ranked, MemRefType type) {
800       auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
801                                                     type.getRank());
802       auto *typeConverter = getTypeConverter();
803       auto ptr =
804           typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
805 
806       auto unrankedType =
807           UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
808       return UnrankedMemRefDescriptor::pack(
809           rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
810     };
811 
812     // Save stack position before promoting descriptors
813     auto stackSaveOp =
814         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
815 
816     auto srcMemRefType = dyn_cast<MemRefType>(srcType);
817     Value unrankedSource =
818         srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
819                       : adaptor.getSource();
820     auto targetMemRefType = dyn_cast<MemRefType>(targetType);
821     Value unrankedTarget =
822         targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
823                          : adaptor.getTarget();
824 
825     // Now promote the unranked descriptors to the stack.
826     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
827                                                  rewriter.getIndexAttr(1));
828     auto promote = [&](Value desc) {
829       auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
830       auto allocated =
831           rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
832       rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
833       return allocated;
834     };
835 
836     auto sourcePtr = promote(unrankedSource);
837     auto targetPtr = promote(unrankedTarget);
838 
839     // Derive size from llvm.getelementptr which will account for any
840     // potential alignment
841     auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
842     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
843         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
844     if (failed(copyFn))
845       return failure();
846     rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
847                                   ValueRange{elemSize, sourcePtr, targetPtr});
848 
849     // Restore stack used for descriptors
850     rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
851 
852     rewriter.eraseOp(op);
853 
854     return success();
855   }
856 
857   LogicalResult
858   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
859                   ConversionPatternRewriter &rewriter) const override {
860     auto srcType = cast<BaseMemRefType>(op.getSource().getType());
861     auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
862 
863     auto isContiguousMemrefType = [&](BaseMemRefType type) {
864       auto memrefType = dyn_cast<mlir::MemRefType>(type);
865       // We can use memcpy for memrefs if they have an identity layout or are
866       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
867       // special case handled by memrefCopy.
868       return memrefType &&
869              (memrefType.getLayout().isIdentity() ||
870               (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
871                memref::isStaticShapeAndContiguousRowMajor(memrefType)));
872     };
873 
874     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
875       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
876 
877     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
878   }
879 };
880 
/// Lowering for `memref.memory_space_cast`. For ranked memrefs the allocated
/// and aligned pointers of the descriptor are rewritten with
/// `llvm.addrspacecast`. For unranked memrefs, a new underlying ranked
/// descriptor is allocated on the stack, the two pointer fields are cast and
/// stored, and the remaining index-valued fields (offset, sizes, strides) are
/// copied over verbatim with a memcpy.
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      // Unpack the descriptor; fields 0 and 1 are the allocated and aligned
      // pointers, which are the only ones affected by the cast.
      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      // Skip the two leading pointer fields of the underlying descriptor;
      // only the trailing offset/sizes/strides block is memcpy'd.
      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};
987 
988 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
989 /// memref type. In unranked case, the fields are extracted from the underlying
990 /// ranked descriptor.
991 static void extractPointersAndOffset(Location loc,
992                                      ConversionPatternRewriter &rewriter,
993                                      const LLVMTypeConverter &typeConverter,
994                                      Value originalOperand,
995                                      Value convertedOperand,
996                                      Value *allocatedPtr, Value *alignedPtr,
997                                      Value *offset = nullptr) {
998   Type operandType = originalOperand.getType();
999   if (isa<MemRefType>(operandType)) {
1000     MemRefDescriptor desc(convertedOperand);
1001     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1002     *alignedPtr = desc.alignedPtr(rewriter, loc);
1003     if (offset != nullptr)
1004       *offset = desc.offset(rewriter, loc);
1005     return;
1006   }
1007 
1008   // These will all cause assert()s on unconvertible types.
1009   unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1010       cast<UnrankedMemRefType>(operandType));
1011   auto elementPtrType =
1012       LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1013 
1014   // Extract pointer to the underlying ranked memref descriptor and cast it to
1015   // ElemType**.
1016   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1017   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1018 
1019   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1020       rewriter, loc, underlyingDescPtr, elementPtrType);
1021   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1022       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1023   if (offset != nullptr) {
1024     *offset = UnrankedMemRefDescriptor::offset(
1025         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1026   }
1027 }
1028 
1029 struct MemRefReinterpretCastOpLowering
1030     : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1031   using ConvertOpToLLVMPattern<
1032       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1033 
1034   LogicalResult
1035   matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1036                   ConversionPatternRewriter &rewriter) const override {
1037     Type srcType = castOp.getSource().getType();
1038 
1039     Value descriptor;
1040     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1041                                                adaptor, &descriptor)))
1042       return failure();
1043     rewriter.replaceOp(castOp, {descriptor});
1044     return success();
1045   }
1046 
1047 private:
1048   LogicalResult convertSourceMemRefToDescriptor(
1049       ConversionPatternRewriter &rewriter, Type srcType,
1050       memref::ReinterpretCastOp castOp,
1051       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1052     MemRefType targetMemRefType =
1053         cast<MemRefType>(castOp.getResult().getType());
1054     auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1055         typeConverter->convertType(targetMemRefType));
1056     if (!llvmTargetDescriptorTy)
1057       return failure();
1058 
1059     // Create descriptor.
1060     Location loc = castOp.getLoc();
1061     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1062 
1063     // Set allocated and aligned pointers.
1064     Value allocatedPtr, alignedPtr;
1065     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1066                              castOp.getSource(), adaptor.getSource(),
1067                              &allocatedPtr, &alignedPtr);
1068     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1069     desc.setAlignedPtr(rewriter, loc, alignedPtr);
1070 
1071     // Set offset.
1072     if (castOp.isDynamicOffset(0))
1073       desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1074     else
1075       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1076 
1077     // Set sizes and strides.
1078     unsigned dynSizeId = 0;
1079     unsigned dynStrideId = 0;
1080     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1081       if (castOp.isDynamicSize(i))
1082         desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1083       else
1084         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1085 
1086       if (castOp.isDynamicStride(i))
1087         desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1088       else
1089         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1090     }
1091     *descriptor = desc;
1092     return success();
1093   }
1094 };
1095 
1096 struct MemRefReshapeOpLowering
1097     : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1098   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1099 
1100   LogicalResult
1101   matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1102                   ConversionPatternRewriter &rewriter) const override {
1103     Type srcType = reshapeOp.getSource().getType();
1104 
1105     Value descriptor;
1106     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1107                                                adaptor, &descriptor)))
1108       return failure();
1109     rewriter.replaceOp(reshapeOp, {descriptor});
1110     return success();
1111   }
1112 
1113 private:
1114   LogicalResult
1115   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1116                                   Type srcType, memref::ReshapeOp reshapeOp,
1117                                   memref::ReshapeOp::Adaptor adaptor,
1118                                   Value *descriptor) const {
1119     auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1120     if (shapeMemRefType.hasStaticShape()) {
1121       MemRefType targetMemRefType =
1122           cast<MemRefType>(reshapeOp.getResult().getType());
1123       auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1124           typeConverter->convertType(targetMemRefType));
1125       if (!llvmTargetDescriptorTy)
1126         return failure();
1127 
1128       // Create descriptor.
1129       Location loc = reshapeOp.getLoc();
1130       auto desc =
1131           MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1132 
1133       // Set allocated and aligned pointers.
1134       Value allocatedPtr, alignedPtr;
1135       extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1136                                reshapeOp.getSource(), adaptor.getSource(),
1137                                &allocatedPtr, &alignedPtr);
1138       desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1139       desc.setAlignedPtr(rewriter, loc, alignedPtr);
1140 
1141       // Extract the offset and strides from the type.
1142       int64_t offset;
1143       SmallVector<int64_t> strides;
1144       if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1145         return rewriter.notifyMatchFailure(
1146             reshapeOp, "failed to get stride and offset exprs");
1147 
1148       if (!isStaticStrideOrOffset(offset))
1149         return rewriter.notifyMatchFailure(reshapeOp,
1150                                            "dynamic offset is unsupported");
1151 
1152       desc.setConstantOffset(rewriter, loc, offset);
1153 
1154       assert(targetMemRefType.getLayout().isIdentity() &&
1155              "Identity layout map is a precondition of a valid reshape op");
1156 
1157       Type indexType = getIndexType();
1158       Value stride = nullptr;
1159       int64_t targetRank = targetMemRefType.getRank();
1160       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1161         if (!ShapedType::isDynamic(strides[i])) {
1162           // If the stride for this dimension is dynamic, then use the product
1163           // of the sizes of the inner dimensions.
1164           stride =
1165               createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1166         } else if (!stride) {
1167           // `stride` is null only in the first iteration of the loop.  However,
1168           // since the target memref has an identity layout, we can safely set
1169           // the innermost stride to 1.
1170           stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1171         }
1172 
1173         Value dimSize;
1174         // If the size of this dimension is dynamic, then load it at runtime
1175         // from the shape operand.
1176         if (!targetMemRefType.isDynamicDim(i)) {
1177           dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1178                                             targetMemRefType.getDimSize(i));
1179         } else {
1180           Value shapeOp = reshapeOp.getShape();
1181           Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1182           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1183           Type indexType = getIndexType();
1184           if (dimSize.getType() != indexType)
1185             dimSize = typeConverter->materializeTargetConversion(
1186                 rewriter, loc, indexType, dimSize);
1187           assert(dimSize && "Invalid memref element type");
1188         }
1189 
1190         desc.setSize(rewriter, loc, i, dimSize);
1191         desc.setStride(rewriter, loc, i, stride);
1192 
1193         // Prepare the stride value for the next dimension.
1194         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1195       }
1196 
1197       *descriptor = desc;
1198       return success();
1199     }
1200 
1201     // The shape is a rank-1 tensor with unknown length.
1202     Location loc = reshapeOp.getLoc();
1203     MemRefDescriptor shapeDesc(adaptor.getShape());
1204     Value resultRank = shapeDesc.size(rewriter, loc, 0);
1205 
1206     // Extract address space and element type.
1207     auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1208     unsigned addressSpace =
1209         *getTypeConverter()->getMemRefAddressSpace(targetType);
1210 
1211     // Create the unranked memref descriptor that holds the ranked one. The
1212     // inner descriptor is allocated on stack.
1213     auto targetDesc = UnrankedMemRefDescriptor::undef(
1214         rewriter, loc, typeConverter->convertType(targetType));
1215     targetDesc.setRank(rewriter, loc, resultRank);
1216     SmallVector<Value, 4> sizes;
1217     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1218                                            targetDesc, addressSpace, sizes);
1219     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1220         loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
1221         sizes.front());
1222     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1223 
1224     // Extract pointers and offset from the source memref.
1225     Value allocatedPtr, alignedPtr, offset;
1226     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1227                              reshapeOp.getSource(), adaptor.getSource(),
1228                              &allocatedPtr, &alignedPtr, &offset);
1229 
1230     // Set pointers and offset.
1231     auto elementPtrType =
1232         LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1233 
1234     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1235                                               elementPtrType, allocatedPtr);
1236     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1237                                             underlyingDescPtr, elementPtrType,
1238                                             alignedPtr);
1239     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1240                                         underlyingDescPtr, elementPtrType,
1241                                         offset);
1242 
1243     // Use the offset pointer as base for further addressing. Copy over the new
1244     // shape and compute strides. For this, we create a loop from rank-1 to 0.
1245     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1246         rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1247     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1248         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1249     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1250     Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1251     Value resultRankMinusOne =
1252         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1253 
1254     Block *initBlock = rewriter.getInsertionBlock();
1255     Type indexType = getTypeConverter()->getIndexType();
1256     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1257 
1258     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1259                                             {indexType, indexType}, {loc, loc});
1260 
1261     // Move the remaining initBlock ops to condBlock.
1262     Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1263     rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1264 
1265     rewriter.setInsertionPointToEnd(initBlock);
1266     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1267                                 condBlock);
1268     rewriter.setInsertionPointToStart(condBlock);
1269     Value indexArg = condBlock->getArgument(0);
1270     Value strideArg = condBlock->getArgument(1);
1271 
1272     Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1273     Value pred = rewriter.create<LLVM::ICmpOp>(
1274         loc, IntegerType::get(rewriter.getContext(), 1),
1275         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1276 
1277     Block *bodyBlock =
1278         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1279     rewriter.setInsertionPointToStart(bodyBlock);
1280 
1281     // Copy size from shape to descriptor.
1282     auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1283     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1284         loc, llvmIndexPtrType,
1285         typeConverter->convertType(shapeMemRefType.getElementType()),
1286         shapeOperandPtr, indexArg);
1287     Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1288     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1289                                       targetSizesBase, indexArg, size);
1290 
1291     // Write stride value and compute next one.
1292     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1293                                         targetStridesBase, indexArg, strideArg);
1294     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1295 
1296     // Decrement loop counter and branch back.
1297     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1298     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1299                                 condBlock);
1300 
1301     Block *remainder =
1302         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1303 
1304     // Hook up the cond exit to the remainder.
1305     rewriter.setInsertionPointToEnd(condBlock);
1306     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1307                                     remainder, std::nullopt);
1308 
1309     // Reset position to beginning of new remainder block.
1310     rewriter.setInsertionPointToStart(remainder);
1311 
1312     *descriptor = targetDesc;
1313     return success();
1314   }
1315 };
1316 
1317 /// RessociatingReshapeOp must be expanded before we reach this stage.
1318 /// Report that information.
1319 template <typename ReshapeOp>
1320 class ReassociatingReshapeOpConversion
1321     : public ConvertOpToLLVMPattern<ReshapeOp> {
1322 public:
1323   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1324   using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1325 
1326   LogicalResult
1327   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1328                   ConversionPatternRewriter &rewriter) const override {
1329     return rewriter.notifyMatchFailure(
1330         reshapeOp,
1331         "reassociation operations should have been expanded beforehand");
1332   }
1333 };
1334 
1335 /// Subviews must be expanded before we reach this stage.
1336 /// Report that information.
1337 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1338   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1339 
1340   LogicalResult
1341   matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1342                   ConversionPatternRewriter &rewriter) const override {
1343     return rewriter.notifyMatchFailure(
1344         subViewOp, "subview operations should have been expanded beforehand");
1345   }
1346 };
1347 
/// Conversion pattern that transforms a transpose op into:
///   1. A new memref descriptor (via `llvm.mlir.undef`) with the allocated
///      pointer, aligned pointer and offset copied from the source
///      descriptor.
///   2. Per-dimension updates that permute the source sizes and strides
///      according to the transpose permutation map.
/// The transpose op is replaced by the new descriptor.
1355 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1356 public:
1357   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1358 
1359   LogicalResult
1360   matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1361                   ConversionPatternRewriter &rewriter) const override {
1362     auto loc = transposeOp.getLoc();
1363     MemRefDescriptor viewMemRef(adaptor.getIn());
1364 
1365     // No permutation, early exit.
1366     if (transposeOp.getPermutation().isIdentity())
1367       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1368 
1369     auto targetMemRef = MemRefDescriptor::undef(
1370         rewriter, loc,
1371         typeConverter->convertType(transposeOp.getIn().getType()));
1372 
1373     // Copy the base and aligned pointers from the old descriptor to the new
1374     // one.
1375     targetMemRef.setAllocatedPtr(rewriter, loc,
1376                                  viewMemRef.allocatedPtr(rewriter, loc));
1377     targetMemRef.setAlignedPtr(rewriter, loc,
1378                                viewMemRef.alignedPtr(rewriter, loc));
1379 
1380     // Copy the offset pointer from the old descriptor to the new one.
1381     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1382 
1383     // Iterate over the dimensions and apply size/stride permutation:
1384     // When enumerating the results of the permutation map, the enumeration
1385     // index is the index into the target dimensions and the DimExpr points to
1386     // the dimension of the source memref.
1387     for (const auto &en :
1388          llvm::enumerate(transposeOp.getPermutation().getResults())) {
1389       int targetPos = en.index();
1390       int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1391       targetMemRef.setSize(rewriter, loc, targetPos,
1392                            viewMemRef.size(rewriter, loc, sourcePos));
1393       targetMemRef.setStride(rewriter, loc, targetPos,
1394                              viewMemRef.stride(rewriter, loc, sourcePos));
1395     }
1396 
1397     rewriter.replaceOp(transposeOp, {targetMemRef});
1398     return success();
1399   }
1400 };
1401 
1402 /// Conversion pattern that transforms an op into:
1403 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
1404 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
1405 ///      and stride.
1406 /// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    // `dynamicSizes` holds one operand per dynamic dimension, so the operand
    // index is the number of dynamic dims that precede `idx`.
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    // Dynamic stride: accumulate the product of the sizes to the right of
    // `idx` (stride[idx] = stride[idx+1] * size[idx+1] for a contiguous
    // memref, walked from the innermost dimension outwards).
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    // Innermost dimension: its stride is 1 by the contiguity check below.
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }

  // Lower memref.view by building a fresh descriptor whose aligned pointer is
  // the source's aligned pointer advanced by `byte_shift`, and whose sizes and
  // strides describe the (contiguous) result type.
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    // Bail out when either the element type or the descriptor type cannot be
    // expressed in the LLVM dialect.
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload, advanced by the
    // byte shift. The GEP is done in units of the *source* element type.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is
    // because of the type change: an offset on srcType* may not be
    // expressible as an offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    // Walk dimensions from innermost to outermost, threading the running
    // stride and the previously-set size through getStride.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
1518 
1519 //===----------------------------------------------------------------------===//
1520 // AtomicRMWOpLowering
1521 //===----------------------------------------------------------------------===//
1522 
1523 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1524 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1525 static std::optional<LLVM::AtomicBinOp>
1526 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1527   switch (atomicOp.getKind()) {
1528   case arith::AtomicRMWKind::addf:
1529     return LLVM::AtomicBinOp::fadd;
1530   case arith::AtomicRMWKind::addi:
1531     return LLVM::AtomicBinOp::add;
1532   case arith::AtomicRMWKind::assign:
1533     return LLVM::AtomicBinOp::xchg;
1534   case arith::AtomicRMWKind::maximumf:
1535     return LLVM::AtomicBinOp::fmax;
1536   case arith::AtomicRMWKind::maxs:
1537     return LLVM::AtomicBinOp::max;
1538   case arith::AtomicRMWKind::maxu:
1539     return LLVM::AtomicBinOp::umax;
1540   case arith::AtomicRMWKind::minimumf:
1541     return LLVM::AtomicBinOp::fmin;
1542   case arith::AtomicRMWKind::mins:
1543     return LLVM::AtomicBinOp::min;
1544   case arith::AtomicRMWKind::minu:
1545     return LLVM::AtomicBinOp::umin;
1546   case arith::AtomicRMWKind::ori:
1547     return LLVM::AtomicBinOp::_or;
1548   case arith::AtomicRMWKind::andi:
1549     return LLVM::AtomicBinOp::_and;
1550   default:
1551     return std::nullopt;
1552   }
1553   llvm_unreachable("Invalid AtomicRMWKind");
1554 }
1555 
1556 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1557   using Base::Base;
1558 
1559   LogicalResult
1560   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1561                   ConversionPatternRewriter &rewriter) const override {
1562     auto maybeKind = matchSimpleAtomicOp(atomicOp);
1563     if (!maybeKind)
1564       return failure();
1565     auto memRefType = atomicOp.getMemRefType();
1566     SmallVector<int64_t> strides;
1567     int64_t offset;
1568     if (failed(memRefType.getStridesAndOffset(strides, offset)))
1569       return failure();
1570     auto dataPtr =
1571         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1572                              adaptor.getIndices(), rewriter);
1573     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1574         atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1575         LLVM::AtomicOrdering::acq_rel);
1576     return success();
1577   }
1578 };
1579 
1580 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1581 class ConvertExtractAlignedPointerAsIndex
1582     : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1583 public:
1584   using ConvertOpToLLVMPattern<
1585       memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1586 
1587   LogicalResult
1588   matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1589                   OpAdaptor adaptor,
1590                   ConversionPatternRewriter &rewriter) const override {
1591     BaseMemRefType sourceTy = extractOp.getSource().getType();
1592 
1593     Value alignedPtr;
1594     if (sourceTy.hasRank()) {
1595       MemRefDescriptor desc(adaptor.getSource());
1596       alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1597     } else {
1598       auto elementPtrTy = LLVM::LLVMPointerType::get(
1599           rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1600 
1601       UnrankedMemRefDescriptor desc(adaptor.getSource());
1602       Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1603 
1604       alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1605           rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1606           elementPtrTy);
1607     }
1608 
1609     rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1610         extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1611     return success();
1612   }
1613 };
1614 
1615 /// Materialize the MemRef descriptor represented by the results of
1616 /// ExtractStridedMetadataOp.
1617 class ExtractStridedMetadataOpLowering
1618     : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1619 public:
1620   using ConvertOpToLLVMPattern<
1621       memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1622 
1623   LogicalResult
1624   matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1625                   OpAdaptor adaptor,
1626                   ConversionPatternRewriter &rewriter) const override {
1627 
1628     if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1629       return failure();
1630 
1631     // Create the descriptor.
1632     MemRefDescriptor sourceMemRef(adaptor.getSource());
1633     Location loc = extractStridedMetadataOp.getLoc();
1634     Value source = extractStridedMetadataOp.getSource();
1635 
1636     auto sourceMemRefType = cast<MemRefType>(source.getType());
1637     int64_t rank = sourceMemRefType.getRank();
1638     SmallVector<Value> results;
1639     results.reserve(2 + rank * 2);
1640 
1641     // Base buffer.
1642     Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1643     Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1644     MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1645         rewriter, loc, *getTypeConverter(),
1646         cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1647         baseBuffer, alignedBuffer);
1648     results.push_back((Value)dstMemRef);
1649 
1650     // Offset.
1651     results.push_back(sourceMemRef.offset(rewriter, loc));
1652 
1653     // Sizes.
1654     for (unsigned i = 0; i < rank; ++i)
1655       results.push_back(sourceMemRef.size(rewriter, loc, i));
1656     // Strides.
1657     for (unsigned i = 0; i < rank; ++i)
1658       results.push_back(sourceMemRef.stride(rewriter, loc, i));
1659 
1660     rewriter.replaceOp(extractStridedMetadataOp, results);
1661     return success();
1662   }
1663 };
1664 
1665 } // namespace
1666 
// Registers all memref-to-LLVM conversion patterns. The alloc/dealloc
// patterns depend on the alloc lowering selected in the converter options.
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Patterns that apply regardless of the alloc lowering choice.
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  // memref.alloc/dealloc lower differently depending on whether aligned_alloc
  // or plain malloc was requested; for any other alloc-lowering option, no
  // alloc/dealloc patterns are added here.
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
1702 
1703 namespace {
1704 struct FinalizeMemRefToLLVMConversionPass
1705     : public impl::FinalizeMemRefToLLVMConversionPassBase<
1706           FinalizeMemRefToLLVMConversionPass> {
1707   using FinalizeMemRefToLLVMConversionPassBase::
1708       FinalizeMemRefToLLVMConversionPassBase;
1709 
1710   void runOnOperation() override {
1711     Operation *op = getOperation();
1712     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1713     LowerToLLVMOptions options(&getContext(),
1714                                dataLayoutAnalysis.getAtOrAbove(op));
1715     options.allocLowering =
1716         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1717                          : LowerToLLVMOptions::AllocLowering::Malloc);
1718 
1719     options.useGenericFunctions = useGenericFunctions;
1720 
1721     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1722       options.overrideIndexBitwidth(indexBitwidth);
1723 
1724     LLVMTypeConverter typeConverter(&getContext(), options,
1725                                     &dataLayoutAnalysis);
1726     RewritePatternSet patterns(&getContext());
1727     populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1728     LLVMConversionTarget target(getContext());
1729     target.addLegalOp<func::FuncOp>();
1730     if (failed(applyPartialConversion(op, target, std::move(patterns))))
1731       signalPassFailure();
1732   }
1733 };
1734 
/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  // The produced patterns create LLVM dialect ops, so that dialect must be
  // loaded before the conversion runs.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};
1750 
1751 } // namespace
1752 
// Registers the ConvertToLLVM interface on the memref dialect. The extension
// runs lazily, when the memref dialect is loaded into a context.
void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}
1758