10bf120a8SThéo Degioanni //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
20bf120a8SThéo Degioanni //
30bf120a8SThéo Degioanni // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40bf120a8SThéo Degioanni // See https://llvm.org/LICENSE.txt for license information.
50bf120a8SThéo Degioanni // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60bf120a8SThéo Degioanni //
70bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
80bf120a8SThéo Degioanni //
90bf120a8SThéo Degioanni // This file implements Mem2Reg-related interfaces for MemRef dialect
100bf120a8SThéo Degioanni // operations.
110bf120a8SThéo Degioanni //
120bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
130bf120a8SThéo Degioanni
140bf120a8SThéo Degioanni #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
150bf120a8SThéo Degioanni #include "mlir/Dialect/MemRef/IR/MemRef.h"
160bf120a8SThéo Degioanni #include "mlir/IR/BuiltinDialect.h"
170bf120a8SThéo Degioanni #include "mlir/IR/BuiltinTypes.h"
180bf120a8SThéo Degioanni #include "mlir/IR/Matchers.h"
190bf120a8SThéo Degioanni #include "mlir/IR/PatternMatch.h"
200bf120a8SThéo Degioanni #include "mlir/IR/Value.h"
210bf120a8SThéo Degioanni #include "mlir/Interfaces/InferTypeOpInterface.h"
220bf120a8SThéo Degioanni #include "mlir/Interfaces/MemorySlotInterfaces.h"
230bf120a8SThéo Degioanni #include "llvm/ADT/ArrayRef.h"
240bf120a8SThéo Degioanni #include "llvm/ADT/TypeSwitch.h"
258404b23aSThéo Degioanni #include "llvm/Support/ErrorHandling.h"
260bf120a8SThéo Degioanni
270bf120a8SThéo Degioanni using namespace mlir;
280bf120a8SThéo Degioanni
290bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
300bf120a8SThéo Degioanni // Utilities
310bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
320bf120a8SThéo Degioanni
330bf120a8SThéo Degioanni /// Walks over the indices of the elements of a tensor of a given `shape` by
340bf120a8SThéo Degioanni /// updating `index` in place to the next index. This returns failure if the
350bf120a8SThéo Degioanni /// provided index was the last index.
nextIndex(ArrayRef<int64_t> shape,MutableArrayRef<int64_t> index)360bf120a8SThéo Degioanni static LogicalResult nextIndex(ArrayRef<int64_t> shape,
370bf120a8SThéo Degioanni MutableArrayRef<int64_t> index) {
380bf120a8SThéo Degioanni for (size_t i = 0; i < shape.size(); ++i) {
390bf120a8SThéo Degioanni index[i]++;
400bf120a8SThéo Degioanni if (index[i] < shape[i])
410bf120a8SThéo Degioanni return success();
420bf120a8SThéo Degioanni index[i] = 0;
430bf120a8SThéo Degioanni }
440bf120a8SThéo Degioanni return failure();
450bf120a8SThéo Degioanni }
460bf120a8SThéo Degioanni
470bf120a8SThéo Degioanni /// Calls `walker` for each index within a tensor of a given `shape`, providing
480bf120a8SThéo Degioanni /// the index as an array attribute of the coordinates.
490bf120a8SThéo Degioanni template <typename CallableT>
walkIndicesAsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape,CallableT && walker)500bf120a8SThéo Degioanni static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
510bf120a8SThéo Degioanni CallableT &&walker) {
520bf120a8SThéo Degioanni Type indexType = IndexType::get(ctx);
530bf120a8SThéo Degioanni SmallVector<int64_t> shapeIter(shape.size(), 0);
540bf120a8SThéo Degioanni do {
550bf120a8SThéo Degioanni SmallVector<Attribute> indexAsAttr;
560bf120a8SThéo Degioanni for (int64_t dim : shapeIter)
570bf120a8SThéo Degioanni indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
580bf120a8SThéo Degioanni walker(ArrayAttr::get(ctx, indexAsAttr));
590bf120a8SThéo Degioanni } while (succeeded(nextIndex(shape, shapeIter)));
600bf120a8SThéo Degioanni }
610bf120a8SThéo Degioanni
620bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
630bf120a8SThéo Degioanni // Interfaces for AllocaOp
640bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
650bf120a8SThéo Degioanni
isSupportedElementType(Type type)660bf120a8SThéo Degioanni static bool isSupportedElementType(Type type) {
6768f58812STres Popp return llvm::isa<MemRefType>(type) ||
680bf120a8SThéo Degioanni OpBuilder(type.getContext()).getZeroAttr(type);
690bf120a8SThéo Degioanni }
700bf120a8SThéo Degioanni
getPromotableSlots()710bf120a8SThéo Degioanni SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
720bf120a8SThéo Degioanni MemRefType type = getType();
730bf120a8SThéo Degioanni if (!isSupportedElementType(type.getElementType()))
740bf120a8SThéo Degioanni return {};
750bf120a8SThéo Degioanni if (!type.hasStaticShape())
760bf120a8SThéo Degioanni return {};
770bf120a8SThéo Degioanni // Make sure the memref contains only a single element.
780bf120a8SThéo Degioanni if (type.getNumElements() != 1)
790bf120a8SThéo Degioanni return {};
800bf120a8SThéo Degioanni
810bf120a8SThéo Degioanni return {MemorySlot{getResult(), type.getElementType()}};
820bf120a8SThéo Degioanni }
830bf120a8SThéo Degioanni
getDefaultValue(const MemorySlot & slot,OpBuilder & builder)840bf120a8SThéo Degioanni Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
85084e2b53SChristian Ulmann OpBuilder &builder) {
860bf120a8SThéo Degioanni assert(isSupportedElementType(slot.elemType));
870bf120a8SThéo Degioanni // TODO: support more types.
880bf120a8SThéo Degioanni return TypeSwitch<Type, Value>(slot.elemType)
890bf120a8SThéo Degioanni .Case([&](MemRefType t) {
90084e2b53SChristian Ulmann return builder.create<memref::AllocaOp>(getLoc(), t);
910bf120a8SThéo Degioanni })
920bf120a8SThéo Degioanni .Default([&](Type t) {
93084e2b53SChristian Ulmann return builder.create<arith::ConstantOp>(getLoc(), t,
94084e2b53SChristian Ulmann builder.getZeroAttr(t));
950bf120a8SThéo Degioanni });
960bf120a8SThéo Degioanni }
970bf120a8SThéo Degioanni
98eeafc9daSChristian Ulmann std::optional<PromotableAllocationOpInterface>
handlePromotionComplete(const MemorySlot & slot,Value defaultValue,OpBuilder & builder)99eeafc9daSChristian Ulmann memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
1000bf120a8SThéo Degioanni Value defaultValue,
101084e2b53SChristian Ulmann OpBuilder &builder) {
1020bf120a8SThéo Degioanni if (defaultValue.use_empty())
103084e2b53SChristian Ulmann defaultValue.getDefiningOp()->erase();
104084e2b53SChristian Ulmann this->erase();
105eeafc9daSChristian Ulmann return std::nullopt;
1060bf120a8SThéo Degioanni }
1070bf120a8SThéo Degioanni
handleBlockArgument(const MemorySlot & slot,BlockArgument argument,OpBuilder & builder)1080bf120a8SThéo Degioanni void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
1090bf120a8SThéo Degioanni BlockArgument argument,
110084e2b53SChristian Ulmann OpBuilder &builder) {}
1110bf120a8SThéo Degioanni
1120bf120a8SThéo Degioanni SmallVector<DestructurableMemorySlot>
getDestructurableSlots()1130bf120a8SThéo Degioanni memref::AllocaOp::getDestructurableSlots() {
1140bf120a8SThéo Degioanni MemRefType memrefType = getType();
11568f58812STres Popp auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
1160bf120a8SThéo Degioanni if (!destructurable)
1170bf120a8SThéo Degioanni return {};
1180bf120a8SThéo Degioanni
119743dd4dbSKazu Hirata std::optional<DenseMap<Attribute, Type>> destructuredType =
1200bf120a8SThéo Degioanni destructurable.getSubelementIndexMap();
1210bf120a8SThéo Degioanni if (!destructuredType)
1220bf120a8SThéo Degioanni return {};
1230bf120a8SThéo Degioanni
1240289ae51SChristian Ulmann return {
1250289ae51SChristian Ulmann DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
1260bf120a8SThéo Degioanni }
1270bf120a8SThéo Degioanni
destructure(const DestructurableMemorySlot & slot,const SmallPtrSetImpl<Attribute> & usedIndices,OpBuilder & builder,SmallVectorImpl<DestructurableAllocationOpInterface> & newAllocators)1280b5b2027SChristian Ulmann DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
1290b5b2027SChristian Ulmann const DestructurableMemorySlot &slot,
1300b5b2027SChristian Ulmann const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1310b5b2027SChristian Ulmann SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
132084e2b53SChristian Ulmann builder.setInsertionPointAfter(*this);
1330bf120a8SThéo Degioanni
1340bf120a8SThéo Degioanni DenseMap<Attribute, MemorySlot> slotMap;
1350bf120a8SThéo Degioanni
13668f58812STres Popp auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
1370bf120a8SThéo Degioanni for (Attribute usedIndex : usedIndices) {
1380bf120a8SThéo Degioanni Type elemType = memrefType.getTypeAtIndex(usedIndex);
1390bf120a8SThéo Degioanni MemRefType elemPtr = MemRefType::get({}, elemType);
140084e2b53SChristian Ulmann auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
1410b5b2027SChristian Ulmann newAllocators.push_back(subAlloca);
1420bf120a8SThéo Degioanni slotMap.try_emplace<MemorySlot>(usedIndex,
1430bf120a8SThéo Degioanni {subAlloca.getResult(), elemType});
1440bf120a8SThéo Degioanni }
1450bf120a8SThéo Degioanni
1460bf120a8SThéo Degioanni return slotMap;
1470bf120a8SThéo Degioanni }
1480bf120a8SThéo Degioanni
1490b5b2027SChristian Ulmann std::optional<DestructurableAllocationOpInterface>
handleDestructuringComplete(const DestructurableMemorySlot & slot,OpBuilder & builder)1500b5b2027SChristian Ulmann memref::AllocaOp::handleDestructuringComplete(
151084e2b53SChristian Ulmann const DestructurableMemorySlot &slot, OpBuilder &builder) {
1520bf120a8SThéo Degioanni assert(slot.ptr == getResult());
153084e2b53SChristian Ulmann this->erase();
1540b5b2027SChristian Ulmann return std::nullopt;
1550bf120a8SThéo Degioanni }
1560bf120a8SThéo Degioanni
1570bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
1580bf120a8SThéo Degioanni // Interfaces for LoadOp/StoreOp
1590bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
1600bf120a8SThéo Degioanni
loadsFrom(const MemorySlot & slot)1610bf120a8SThéo Degioanni bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
1620bf120a8SThéo Degioanni return getMemRef() == slot.ptr;
1630bf120a8SThéo Degioanni }
1640bf120a8SThéo Degioanni
storesTo(const MemorySlot & slot)1658404b23aSThéo Degioanni bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
1668404b23aSThéo Degioanni
getStored(const MemorySlot & slot,OpBuilder & builder,Value reachingDef,const DataLayout & dataLayout)167084e2b53SChristian Ulmann Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1686e9ea6eaSChristian Ulmann Value reachingDef,
169ac39fa74SChristian Ulmann const DataLayout &dataLayout) {
1708404b23aSThéo Degioanni llvm_unreachable("getStored should not be called on LoadOp");
1718404b23aSThéo Degioanni }
1720bf120a8SThéo Degioanni
canUsesBeRemoved(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,SmallVectorImpl<OpOperand * > & newBlockingUses,const DataLayout & dataLayout)1730bf120a8SThéo Degioanni bool memref::LoadOp::canUsesBeRemoved(
1740bf120a8SThéo Degioanni const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
17598c6bc53SChristian Ulmann SmallVectorImpl<OpOperand *> &newBlockingUses,
17698c6bc53SChristian Ulmann const DataLayout &dataLayout) {
1770bf120a8SThéo Degioanni if (blockingUses.size() != 1)
1780bf120a8SThéo Degioanni return false;
1790bf120a8SThéo Degioanni Value blockingUse = (*blockingUses.begin())->get();
1800bf120a8SThéo Degioanni return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
1810bf120a8SThéo Degioanni getResult().getType() == slot.elemType;
1820bf120a8SThéo Degioanni }
1830bf120a8SThéo Degioanni
removeBlockingUses(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,OpBuilder & builder,Value reachingDefinition,const DataLayout & dataLayout)1840bf120a8SThéo Degioanni DeletionKind memref::LoadOp::removeBlockingUses(
1850bf120a8SThéo Degioanni const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
186084e2b53SChristian Ulmann OpBuilder &builder, Value reachingDefinition,
187ac39fa74SChristian Ulmann const DataLayout &dataLayout) {
1880bf120a8SThéo Degioanni // `canUsesBeRemoved` checked this blocking use must be the loaded slot
1890bf120a8SThéo Degioanni // pointer.
190084e2b53SChristian Ulmann getResult().replaceAllUsesWith(reachingDefinition);
1910bf120a8SThéo Degioanni return DeletionKind::Delete;
1920bf120a8SThéo Degioanni }
1930bf120a8SThéo Degioanni
194b142501eSThéo Degioanni /// Returns the index of a memref in attribute form, given its indices. Returns
195b142501eSThéo Degioanni /// a null pointer if whether the indices form a valid index for the provided
196b142501eSThéo Degioanni /// MemRefType cannot be computed. The indices must come from a valid memref
197b142501eSThéo Degioanni /// StoreOp or LoadOp.
getAttributeIndexFromIndexOperands(MLIRContext * ctx,ValueRange indices,MemRefType memrefType)1980bf120a8SThéo Degioanni static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
199b142501eSThéo Degioanni ValueRange indices,
200b142501eSThéo Degioanni MemRefType memrefType) {
2010bf120a8SThéo Degioanni SmallVector<Attribute> index;
202b142501eSThéo Degioanni for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
2030bf120a8SThéo Degioanni IntegerAttr coordAttr;
2040bf120a8SThéo Degioanni if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
2050bf120a8SThéo Degioanni return {};
206b142501eSThéo Degioanni // MemRefType shape dimensions are always positive (checked by verifier).
207b142501eSThéo Degioanni std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
208b142501eSThéo Degioanni if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
209b142501eSThéo Degioanni return {};
2100bf120a8SThéo Degioanni index.push_back(coordAttr);
2110bf120a8SThéo Degioanni }
2120bf120a8SThéo Degioanni return ArrayAttr::get(ctx, index);
2130bf120a8SThéo Degioanni }
2140bf120a8SThéo Degioanni
canRewire(const DestructurableMemorySlot & slot,SmallPtrSetImpl<Attribute> & usedIndices,SmallVectorImpl<MemorySlot> & mustBeSafelyUsed,const DataLayout & dataLayout)2150bf120a8SThéo Degioanni bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
2160bf120a8SThéo Degioanni SmallPtrSetImpl<Attribute> &usedIndices,
21798c6bc53SChristian Ulmann SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
21898c6bc53SChristian Ulmann const DataLayout &dataLayout) {
2190bf120a8SThéo Degioanni if (slot.ptr != getMemRef())
2200bf120a8SThéo Degioanni return false;
221b142501eSThéo Degioanni Attribute index = getAttributeIndexFromIndexOperands(
222b142501eSThéo Degioanni getContext(), getIndices(), getMemRefType());
2230bf120a8SThéo Degioanni if (!index)
2240bf120a8SThéo Degioanni return false;
2250bf120a8SThéo Degioanni usedIndices.insert(index);
2260bf120a8SThéo Degioanni return true;
2270bf120a8SThéo Degioanni }
2280bf120a8SThéo Degioanni
rewire(const DestructurableMemorySlot & slot,DenseMap<Attribute,MemorySlot> & subslots,OpBuilder & builder,const DataLayout & dataLayout)2290bf120a8SThéo Degioanni DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
2300bf120a8SThéo Degioanni DenseMap<Attribute, MemorySlot> &subslots,
231084e2b53SChristian Ulmann OpBuilder &builder,
23298c6bc53SChristian Ulmann const DataLayout &dataLayout) {
233b142501eSThéo Degioanni Attribute index = getAttributeIndexFromIndexOperands(
234b142501eSThéo Degioanni getContext(), getIndices(), getMemRefType());
2350bf120a8SThéo Degioanni const MemorySlot &memorySlot = subslots.at(index);
2360bf120a8SThéo Degioanni setMemRef(memorySlot.ptr);
2370bf120a8SThéo Degioanni getIndicesMutable().clear();
2380bf120a8SThéo Degioanni return DeletionKind::Keep;
2390bf120a8SThéo Degioanni }
2400bf120a8SThéo Degioanni
loadsFrom(const MemorySlot & slot)2410bf120a8SThéo Degioanni bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
2420bf120a8SThéo Degioanni
storesTo(const MemorySlot & slot)2438404b23aSThéo Degioanni bool memref::StoreOp::storesTo(const MemorySlot &slot) {
2448404b23aSThéo Degioanni return getMemRef() == slot.ptr;
2458404b23aSThéo Degioanni }
2468404b23aSThéo Degioanni
getStored(const MemorySlot & slot,OpBuilder & builder,Value reachingDef,const DataLayout & dataLayout)247084e2b53SChristian Ulmann Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
2486e9ea6eaSChristian Ulmann Value reachingDef,
249ac39fa74SChristian Ulmann const DataLayout &dataLayout) {
2500bf120a8SThéo Degioanni return getValue();
2510bf120a8SThéo Degioanni }
2520bf120a8SThéo Degioanni
canUsesBeRemoved(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,SmallVectorImpl<OpOperand * > & newBlockingUses,const DataLayout & dataLayout)2530bf120a8SThéo Degioanni bool memref::StoreOp::canUsesBeRemoved(
2540bf120a8SThéo Degioanni const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
25598c6bc53SChristian Ulmann SmallVectorImpl<OpOperand *> &newBlockingUses,
25698c6bc53SChristian Ulmann const DataLayout &dataLayout) {
2570bf120a8SThéo Degioanni if (blockingUses.size() != 1)
2580bf120a8SThéo Degioanni return false;
2590bf120a8SThéo Degioanni Value blockingUse = (*blockingUses.begin())->get();
2600bf120a8SThéo Degioanni return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
2610bf120a8SThéo Degioanni getValue() != slot.ptr && getValue().getType() == slot.elemType;
2620bf120a8SThéo Degioanni }
2630bf120a8SThéo Degioanni
removeBlockingUses(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,OpBuilder & builder,Value reachingDefinition,const DataLayout & dataLayout)2640bf120a8SThéo Degioanni DeletionKind memref::StoreOp::removeBlockingUses(
2650bf120a8SThéo Degioanni const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
266084e2b53SChristian Ulmann OpBuilder &builder, Value reachingDefinition,
267ac39fa74SChristian Ulmann const DataLayout &dataLayout) {
2680bf120a8SThéo Degioanni return DeletionKind::Delete;
2690bf120a8SThéo Degioanni }
2700bf120a8SThéo Degioanni
canRewire(const DestructurableMemorySlot & slot,SmallPtrSetImpl<Attribute> & usedIndices,SmallVectorImpl<MemorySlot> & mustBeSafelyUsed,const DataLayout & dataLayout)2710bf120a8SThéo Degioanni bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
2720bf120a8SThéo Degioanni SmallPtrSetImpl<Attribute> &usedIndices,
27398c6bc53SChristian Ulmann SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
27498c6bc53SChristian Ulmann const DataLayout &dataLayout) {
2750bf120a8SThéo Degioanni if (slot.ptr != getMemRef() || getValue() == slot.ptr)
2760bf120a8SThéo Degioanni return false;
277b142501eSThéo Degioanni Attribute index = getAttributeIndexFromIndexOperands(
278b142501eSThéo Degioanni getContext(), getIndices(), getMemRefType());
279*69d3793fSThéo Degioanni if (!index || !slot.subelementTypes.contains(index))
2800bf120a8SThéo Degioanni return false;
2810bf120a8SThéo Degioanni usedIndices.insert(index);
2820bf120a8SThéo Degioanni return true;
2830bf120a8SThéo Degioanni }
2840bf120a8SThéo Degioanni
rewire(const DestructurableMemorySlot & slot,DenseMap<Attribute,MemorySlot> & subslots,OpBuilder & builder,const DataLayout & dataLayout)2850bf120a8SThéo Degioanni DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
2860bf120a8SThéo Degioanni DenseMap<Attribute, MemorySlot> &subslots,
287084e2b53SChristian Ulmann OpBuilder &builder,
28898c6bc53SChristian Ulmann const DataLayout &dataLayout) {
289b142501eSThéo Degioanni Attribute index = getAttributeIndexFromIndexOperands(
290b142501eSThéo Degioanni getContext(), getIndices(), getMemRefType());
2910bf120a8SThéo Degioanni const MemorySlot &memorySlot = subslots.at(index);
2920bf120a8SThéo Degioanni setMemRef(memorySlot.ptr);
2930bf120a8SThéo Degioanni getIndicesMutable().clear();
2940bf120a8SThéo Degioanni return DeletionKind::Keep;
2950bf120a8SThéo Degioanni }
2960bf120a8SThéo Degioanni
2970bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
2980bf120a8SThéo Degioanni // Interfaces for destructurable types
2990bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
3000bf120a8SThéo Degioanni
3010bf120a8SThéo Degioanni namespace {
3020bf120a8SThéo Degioanni
3030bf120a8SThéo Degioanni struct MemRefDestructurableTypeExternalModel
3040bf120a8SThéo Degioanni : public DestructurableTypeInterface::ExternalModel<
3050bf120a8SThéo Degioanni MemRefDestructurableTypeExternalModel, MemRefType> {
3060bf120a8SThéo Degioanni std::optional<DenseMap<Attribute, Type>>
getSubelementIndexMap__anon8a6d594d0311::MemRefDestructurableTypeExternalModel3070bf120a8SThéo Degioanni getSubelementIndexMap(Type type) const {
30868f58812STres Popp auto memrefType = llvm::cast<MemRefType>(type);
3090bf120a8SThéo Degioanni constexpr int64_t maxMemrefSizeForDestructuring = 16;
3100bf120a8SThéo Degioanni if (!memrefType.hasStaticShape() ||
3110bf120a8SThéo Degioanni memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
3120bf120a8SThéo Degioanni memrefType.getNumElements() == 1)
3130bf120a8SThéo Degioanni return {};
3140bf120a8SThéo Degioanni
3150bf120a8SThéo Degioanni DenseMap<Attribute, Type> destructured;
3160bf120a8SThéo Degioanni walkIndicesAsAttr(
3170bf120a8SThéo Degioanni memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
3180bf120a8SThéo Degioanni destructured.insert({index, memrefType.getElementType()});
3190bf120a8SThéo Degioanni });
3200bf120a8SThéo Degioanni
3210bf120a8SThéo Degioanni return destructured;
3220bf120a8SThéo Degioanni }
3230bf120a8SThéo Degioanni
getTypeAtIndex__anon8a6d594d0311::MemRefDestructurableTypeExternalModel3240bf120a8SThéo Degioanni Type getTypeAtIndex(Type type, Attribute index) const {
32568f58812STres Popp auto memrefType = llvm::cast<MemRefType>(type);
32668f58812STres Popp auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
3270bf120a8SThéo Degioanni if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
3280bf120a8SThéo Degioanni return {};
3290bf120a8SThéo Degioanni
3300bf120a8SThéo Degioanni Type indexType = IndexType::get(memrefType.getContext());
3310bf120a8SThéo Degioanni for (const auto &[coordAttr, dimSize] :
3320bf120a8SThéo Degioanni llvm::zip(coordArrAttr, memrefType.getShape())) {
33368f58812STres Popp auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
3340bf120a8SThéo Degioanni if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
3350bf120a8SThéo Degioanni coord.getInt() >= dimSize)
3360bf120a8SThéo Degioanni return {};
3370bf120a8SThéo Degioanni }
3380bf120a8SThéo Degioanni
3390bf120a8SThéo Degioanni return memrefType.getElementType();
3400bf120a8SThéo Degioanni }
3410bf120a8SThéo Degioanni };
3420bf120a8SThéo Degioanni
3430bf120a8SThéo Degioanni } // namespace
3440bf120a8SThéo Degioanni
3450bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
3460bf120a8SThéo Degioanni // Register external models
3470bf120a8SThéo Degioanni //===----------------------------------------------------------------------===//
3480bf120a8SThéo Degioanni
registerMemorySlotExternalModels(DialectRegistry & registry)3490bf120a8SThéo Degioanni void mlir::memref::registerMemorySlotExternalModels(DialectRegistry ®istry) {
3500bf120a8SThéo Degioanni registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
3510bf120a8SThéo Degioanni MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
3520bf120a8SThéo Degioanni });
3530bf120a8SThéo Degioanni }
354