//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Mem2Reg-related interfaces for MemRef dialect
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// Walks over the indices of the elements of a tensor of a given `shape` by
/// updating `index` in place to the next index. This returns failure if the
/// provided index was the last index.
static LogicalResult nextIndex(ArrayRef<int64_t> shape,
                               MutableArrayRef<int64_t> index) {
  for (size_t i = 0; i < shape.size(); ++i) {
    index[i]++;
    if (index[i] < shape[i])
      return success();
    index[i] = 0;
  }
  return failure();
}

/// Calls `walker` for each index within a tensor of a given `shape`, providing
/// the index as an array attribute of the coordinates.
template <typename CallableT>
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
                              CallableT &&walker) {
  Type indexType = IndexType::get(ctx);
  SmallVector<int64_t> shapeIter(shape.size(), 0);
  do {
    SmallVector<Attribute> indexAsAttr;
    for (int64_t dim : shapeIter)
      indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
    walker(ArrayAttr::get(ctx, indexAsAttr));
  } while (succeeded(nextIndex(shape, shapeIter)));
}

//===----------------------------------------------------------------------===//
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//

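/// Returns whether the given type is supported as the element type of a
/// promotable memref slot: either a memref type, or any type for which a zero
/// attribute (and thus a default value) can be constructed.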
static bool isSupportedElementType(Type type) {
  return llvm::isa<MemRefType>(type) ||
         OpBuilder(type.getContext()).getZeroAttr(type);
}

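/// A memref.alloca is promotable only when it holds exactly one element of a
/// supported type, e.g. a pattern such as (illustrative):
///   %slot = memref.alloca() : memref<i32>
///   memref.store %v, %slot[] : memref<i32>
///   %ld = memref.load %slot[] : memref<i32>
/// where the mem2reg driver can forward %v directly to the load.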
SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
  MemRefType type = getType();
  if (!isSupportedElementType(type.getElementType()))
    return {};
  if (!type.hasStaticShape())
    return {};
  // Make sure the memref contains only a single element.
  if (type.getNumElements() != 1)
    return {};

  return {MemorySlot{getResult(), type.getElementType()}};
}

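/// Builds the value the promoted slot holds before any store reaches it: a
/// fresh alloca for memref element types, and a zero constant otherwise.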
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
                                        OpBuilder &builder) {
  assert(isSupportedElementType(slot.elemType));
  // TODO: support more types.
  return TypeSwitch<Type, Value>(slot.elemType)
      .Case([&](MemRefType t) {
        return builder.create<memref::AllocaOp>(getLoc(), t);
      })
      .Default([&](Type t) {
        return builder.create<arith::ConstantOp>(getLoc(), t,
                                                 builder.getZeroAttr(t));
      });
}

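/// Once promotion has completed, the alloca is erased; the materialized
/// default value is erased as well if it ended up unused. No follow-up
/// allocator is returned.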
std::optional<PromotableAllocationOpInterface>
memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                          Value defaultValue,
                                          OpBuilder &builder) {
  if (defaultValue.use_empty())
    defaultValue.getDefiningOp()->erase();
  this->erase();
  return std::nullopt;
}

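/// Block arguments created while promoting a memref.alloca slot require no
/// additional bookkeeping.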
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                           BlockArgument argument,
                                           OpBuilder &builder) {}

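/// A memref.alloca with a small static shape can be destructured (SROA) into
/// one slot per element, e.g. (illustrative):
///   %arr = memref.alloca() : memref<2xi32>
/// may be split into two single-element allocas:
///   %e0 = memref.alloca() : memref<i32>
///   %e1 = memref.alloca() : memref<i32>
/// The per-element index map comes from the DestructurableTypeInterface
/// external model registered at the bottom of this file.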
SmallVector<DestructurableMemorySlot>
memref::AllocaOp::getDestructurableSlots() {
  MemRefType memrefType = getType();
  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
  if (!destructurable)
    return {};

  std::optional<DenseMap<Attribute, Type>> destructuredType =
      destructurable.getSubelementIndexMap();
  if (!destructuredType)
    return {};

  return {
      DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
}

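/// Creates one single-element alloca per index that is actually used, and
/// maps each used index to its new sub-slot.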
DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
    const DestructurableMemorySlot &slot,
    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
  builder.setInsertionPointAfter(*this);

  DenseMap<Attribute, MemorySlot> slotMap;

  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
  for (Attribute usedIndex : usedIndices) {
    Type elemType = memrefType.getTypeAtIndex(usedIndex);
    MemRefType elemPtr = MemRefType::get({}, elemType);
    auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
    newAllocators.push_back(subAlloca);
    slotMap.try_emplace<MemorySlot>(usedIndex,
                                    {subAlloca.getResult(), elemType});
  }

  return slotMap;
}

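/// Once all accesses have been rewired to the new sub-slots, the original
/// alloca can simply be erased.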
std::optional<DestructurableAllocationOpInterface>
memref::AllocaOp::handleDestructuringComplete(
    const DestructurableMemorySlot &slot, OpBuilder &builder) {
  assert(slot.ptr == getResult());
  this->erase();
  return std::nullopt;
}

//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//

bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
  return getMemRef() == slot.ptr;
}

bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }

Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}

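/// A load can be removed during promotion only if it is the sole blocking
/// use, it reads the promoted slot directly, and its result type matches the
/// slot's element type; it is then replaced by the reaching definition.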
bool memref::LoadOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (blockingUses.size() != 1)
    return false;
  Value blockingUse = (*blockingUses.begin())->get();
  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
         getResult().getType() == slot.elemType;
}

DeletionKind memref::LoadOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
  // pointer.
  getResult().replaceAllUsesWith(reachingDefinition);
  return DeletionKind::Delete;
}

/// Returns the index into a memref in attribute form, given its index
/// operands. Returns a null attribute if the operands cannot be shown to form
/// a valid, in-bounds index for the provided MemRefType. The indices must come
/// from a valid memref StoreOp or LoadOp.
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
                                                    ValueRange indices,
                                                    MemRefType memrefType) {
  SmallVector<Attribute> index;
  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
    IntegerAttr coordAttr;
    if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
      return {};
    // MemRefType shape dimensions are always positive (checked by verifier).
    std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
    if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
      return {};
    index.push_back(coordAttr);
  }
  return ArrayAttr::get(ctx, index);
}

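/// A load can be rewired to a sub-slot only when it accesses the destructured
/// slot through constant, in-bounds indices; the accessed index is recorded
/// so the corresponding sub-slot gets materialized.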
bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  if (slot.ptr != getMemRef())
    return false;
  Attribute index = getAttributeIndexFromIndexOperands(
      getContext(), getIndices(), getMemRefType());
  if (!index)
    return false;
  usedIndices.insert(index);
  return true;
}

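/// Redirects the load to the single-element sub-slot for its index and drops
/// the now-unneeded index operands; the op itself is kept.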
DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  Attribute index = getAttributeIndexFromIndexOperands(
      getContext(), getIndices(), getMemRefType());
  const MemorySlot &memorySlot = subslots.at(index);
  setMemRef(memorySlot.ptr);
  getIndicesMutable().clear();
  return DeletionKind::Keep;
}

bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }

bool memref::StoreOp::storesTo(const MemorySlot &slot) {
  return getMemRef() == slot.ptr;
}

Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
  return getValue();
}

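/// A store can be removed during promotion under the same conditions as a
/// load, with the additional requirement that the stored value is not the
/// slot pointer itself.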
bool memref::StoreOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (blockingUses.size() != 1)
    return false;
  Value blockingUse = (*blockingUses.begin())->get();
  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
         getValue() != slot.ptr && getValue().getType() == slot.elemType;
}

DeletionKind memref::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}

bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                                const DataLayout &dataLayout) {
  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
    return false;
  Attribute index = getAttributeIndexFromIndexOperands(
      getContext(), getIndices(), getMemRefType());
  if (!index || !slot.subelementTypes.contains(index))
    return false;
  usedIndices.insert(index);
  return true;
}

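/// Redirects the store to the single-element sub-slot for its index and drops
/// the index operands, mirroring LoadOp::rewire.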
DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     OpBuilder &builder,
                                     const DataLayout &dataLayout) {
  Attribute index = getAttributeIndexFromIndexOperands(
      getContext(), getIndices(), getMemRefType());
  const MemorySlot &memorySlot = subslots.at(index);
  setMemRef(memorySlot.ptr);
  getIndicesMutable().clear();
  return DeletionKind::Keep;
}

//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//

namespace {

struct MemRefDestructurableTypeExternalModel
    : public DestructurableTypeInterface::ExternalModel<
          MemRefDestructurableTypeExternalModel, MemRefType> {
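  /// Maps each element index of a small, statically shaped memref to its
  /// element type. Destructuring is only offered for memrefs with more than
  /// one and at most 16 elements.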
  std::optional<DenseMap<Attribute, Type>>
  getSubelementIndexMap(Type type) const {
    auto memrefType = llvm::cast<MemRefType>(type);
    constexpr int64_t maxMemrefSizeForDestructuring = 16;
    if (!memrefType.hasStaticShape() ||
        memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
        memrefType.getNumElements() == 1)
      return {};

    DenseMap<Attribute, Type> destructured;
    walkIndicesAsAttr(
        memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
          destructured.insert({index, memrefType.getElementType()});
        });

    return destructured;
  }

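  /// Returns the element type for an index given as an ArrayAttr of in-bounds
  /// index-typed coordinates, or a null type if the index is malformed.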
  Type getTypeAtIndex(Type type, Attribute index) const {
    auto memrefType = llvm::cast<MemRefType>(type);
    auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
    if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
      return {};

    Type indexType = IndexType::get(memrefType.getContext());
    for (const auto &[coordAttr, dimSize] :
         llvm::zip(coordArrAttr, memrefType.getShape())) {
      auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
      if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
          coord.getInt() >= dimSize)
        return {};
    }

    return memrefType.getElementType();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Register external models
//===----------------------------------------------------------------------===//

void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
    MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
  });
}