xref: /llvm-project/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Mem2Reg-related interfaces for MemRef dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/Value.h"
21 #include "mlir/Interfaces/InferTypeOpInterface.h"
22 #include "mlir/Interfaces/MemorySlotInterfaces.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 //  Utilities
31 //===----------------------------------------------------------------------===//
32 
33 /// Walks over the indices of the elements of a tensor of a given `shape` by
34 /// updating `index` in place to the next index. This returns failure if the
35 /// provided index was the last index.
nextIndex(ArrayRef<int64_t> shape,MutableArrayRef<int64_t> index)36 static LogicalResult nextIndex(ArrayRef<int64_t> shape,
37                                MutableArrayRef<int64_t> index) {
38   for (size_t i = 0; i < shape.size(); ++i) {
39     index[i]++;
40     if (index[i] < shape[i])
41       return success();
42     index[i] = 0;
43   }
44   return failure();
45 }
46 
47 /// Calls `walker` for each index within a tensor of a given `shape`, providing
48 /// the index as an array attribute of the coordinates.
49 template <typename CallableT>
walkIndicesAsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape,CallableT && walker)50 static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
51                               CallableT &&walker) {
52   Type indexType = IndexType::get(ctx);
53   SmallVector<int64_t> shapeIter(shape.size(), 0);
54   do {
55     SmallVector<Attribute> indexAsAttr;
56     for (int64_t dim : shapeIter)
57       indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
58     walker(ArrayAttr::get(ctx, indexAsAttr));
59   } while (succeeded(nextIndex(shape, shapeIter)));
60 }
61 
62 //===----------------------------------------------------------------------===//
63 //  Interfaces for AllocaOp
64 //===----------------------------------------------------------------------===//
65 
isSupportedElementType(Type type)66 static bool isSupportedElementType(Type type) {
67   return llvm::isa<MemRefType>(type) ||
68          OpBuilder(type.getContext()).getZeroAttr(type);
69 }
70 
getPromotableSlots()71 SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
72   MemRefType type = getType();
73   if (!isSupportedElementType(type.getElementType()))
74     return {};
75   if (!type.hasStaticShape())
76     return {};
77   // Make sure the memref contains only a single element.
78   if (type.getNumElements() != 1)
79     return {};
80 
81   return {MemorySlot{getResult(), type.getElementType()}};
82 }
83 
getDefaultValue(const MemorySlot & slot,OpBuilder & builder)84 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
85                                         OpBuilder &builder) {
86   assert(isSupportedElementType(slot.elemType));
87   // TODO: support more types.
88   return TypeSwitch<Type, Value>(slot.elemType)
89       .Case([&](MemRefType t) {
90         return builder.create<memref::AllocaOp>(getLoc(), t);
91       })
92       .Default([&](Type t) {
93         return builder.create<arith::ConstantOp>(getLoc(), t,
94                                                  builder.getZeroAttr(t));
95       });
96 }
97 
98 std::optional<PromotableAllocationOpInterface>
handlePromotionComplete(const MemorySlot & slot,Value defaultValue,OpBuilder & builder)99 memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100                                           Value defaultValue,
101                                           OpBuilder &builder) {
102   if (defaultValue.use_empty())
103     defaultValue.getDefiningOp()->erase();
104   this->erase();
105   return std::nullopt;
106 }
107 
handleBlockArgument(const MemorySlot & slot,BlockArgument argument,OpBuilder & builder)108 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
109                                            BlockArgument argument,
110                                            OpBuilder &builder) {}
111 
112 SmallVector<DestructurableMemorySlot>
getDestructurableSlots()113 memref::AllocaOp::getDestructurableSlots() {
114   MemRefType memrefType = getType();
115   auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
116   if (!destructurable)
117     return {};
118 
119   std::optional<DenseMap<Attribute, Type>> destructuredType =
120       destructurable.getSubelementIndexMap();
121   if (!destructuredType)
122     return {};
123 
124   return {
125       DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
126 }
127 
destructure(const DestructurableMemorySlot & slot,const SmallPtrSetImpl<Attribute> & usedIndices,OpBuilder & builder,SmallVectorImpl<DestructurableAllocationOpInterface> & newAllocators)128 DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
129     const DestructurableMemorySlot &slot,
130     const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
131     SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
132   builder.setInsertionPointAfter(*this);
133 
134   DenseMap<Attribute, MemorySlot> slotMap;
135 
136   auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
137   for (Attribute usedIndex : usedIndices) {
138     Type elemType = memrefType.getTypeAtIndex(usedIndex);
139     MemRefType elemPtr = MemRefType::get({}, elemType);
140     auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
141     newAllocators.push_back(subAlloca);
142     slotMap.try_emplace<MemorySlot>(usedIndex,
143                                     {subAlloca.getResult(), elemType});
144   }
145 
146   return slotMap;
147 }
148 
149 std::optional<DestructurableAllocationOpInterface>
handleDestructuringComplete(const DestructurableMemorySlot & slot,OpBuilder & builder)150 memref::AllocaOp::handleDestructuringComplete(
151     const DestructurableMemorySlot &slot, OpBuilder &builder) {
152   assert(slot.ptr == getResult());
153   this->erase();
154   return std::nullopt;
155 }
156 
157 //===----------------------------------------------------------------------===//
158 //  Interfaces for LoadOp/StoreOp
159 //===----------------------------------------------------------------------===//
160 
loadsFrom(const MemorySlot & slot)161 bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
162   return getMemRef() == slot.ptr;
163 }
164 
storesTo(const MemorySlot & slot)165 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
166 
getStored(const MemorySlot & slot,OpBuilder & builder,Value reachingDef,const DataLayout & dataLayout)167 Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
168                                 Value reachingDef,
169                                 const DataLayout &dataLayout) {
170   llvm_unreachable("getStored should not be called on LoadOp");
171 }
172 
canUsesBeRemoved(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,SmallVectorImpl<OpOperand * > & newBlockingUses,const DataLayout & dataLayout)173 bool memref::LoadOp::canUsesBeRemoved(
174     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
175     SmallVectorImpl<OpOperand *> &newBlockingUses,
176     const DataLayout &dataLayout) {
177   if (blockingUses.size() != 1)
178     return false;
179   Value blockingUse = (*blockingUses.begin())->get();
180   return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
181          getResult().getType() == slot.elemType;
182 }
183 
removeBlockingUses(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,OpBuilder & builder,Value reachingDefinition,const DataLayout & dataLayout)184 DeletionKind memref::LoadOp::removeBlockingUses(
185     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
186     OpBuilder &builder, Value reachingDefinition,
187     const DataLayout &dataLayout) {
188   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
189   // pointer.
190   getResult().replaceAllUsesWith(reachingDefinition);
191   return DeletionKind::Delete;
192 }
193 
194 /// Returns the index of a memref in attribute form, given its indices. Returns
195 /// a null pointer if whether the indices form a valid index for the provided
196 /// MemRefType cannot be computed. The indices must come from a valid memref
197 /// StoreOp or LoadOp.
getAttributeIndexFromIndexOperands(MLIRContext * ctx,ValueRange indices,MemRefType memrefType)198 static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
199                                                     ValueRange indices,
200                                                     MemRefType memrefType) {
201   SmallVector<Attribute> index;
202   for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
203     IntegerAttr coordAttr;
204     if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
205       return {};
206     // MemRefType shape dimensions are always positive (checked by verifier).
207     std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
208     if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
209       return {};
210     index.push_back(coordAttr);
211   }
212   return ArrayAttr::get(ctx, index);
213 }
214 
canRewire(const DestructurableMemorySlot & slot,SmallPtrSetImpl<Attribute> & usedIndices,SmallVectorImpl<MemorySlot> & mustBeSafelyUsed,const DataLayout & dataLayout)215 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
216                                SmallPtrSetImpl<Attribute> &usedIndices,
217                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
218                                const DataLayout &dataLayout) {
219   if (slot.ptr != getMemRef())
220     return false;
221   Attribute index = getAttributeIndexFromIndexOperands(
222       getContext(), getIndices(), getMemRefType());
223   if (!index)
224     return false;
225   usedIndices.insert(index);
226   return true;
227 }
228 
rewire(const DestructurableMemorySlot & slot,DenseMap<Attribute,MemorySlot> & subslots,OpBuilder & builder,const DataLayout & dataLayout)229 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
230                                     DenseMap<Attribute, MemorySlot> &subslots,
231                                     OpBuilder &builder,
232                                     const DataLayout &dataLayout) {
233   Attribute index = getAttributeIndexFromIndexOperands(
234       getContext(), getIndices(), getMemRefType());
235   const MemorySlot &memorySlot = subslots.at(index);
236   setMemRef(memorySlot.ptr);
237   getIndicesMutable().clear();
238   return DeletionKind::Keep;
239 }
240 
loadsFrom(const MemorySlot & slot)241 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
242 
storesTo(const MemorySlot & slot)243 bool memref::StoreOp::storesTo(const MemorySlot &slot) {
244   return getMemRef() == slot.ptr;
245 }
246 
getStored(const MemorySlot & slot,OpBuilder & builder,Value reachingDef,const DataLayout & dataLayout)247 Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
248                                  Value reachingDef,
249                                  const DataLayout &dataLayout) {
250   return getValue();
251 }
252 
canUsesBeRemoved(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,SmallVectorImpl<OpOperand * > & newBlockingUses,const DataLayout & dataLayout)253 bool memref::StoreOp::canUsesBeRemoved(
254     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
255     SmallVectorImpl<OpOperand *> &newBlockingUses,
256     const DataLayout &dataLayout) {
257   if (blockingUses.size() != 1)
258     return false;
259   Value blockingUse = (*blockingUses.begin())->get();
260   return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
261          getValue() != slot.ptr && getValue().getType() == slot.elemType;
262 }
263 
removeBlockingUses(const MemorySlot & slot,const SmallPtrSetImpl<OpOperand * > & blockingUses,OpBuilder & builder,Value reachingDefinition,const DataLayout & dataLayout)264 DeletionKind memref::StoreOp::removeBlockingUses(
265     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
266     OpBuilder &builder, Value reachingDefinition,
267     const DataLayout &dataLayout) {
268   return DeletionKind::Delete;
269 }
270 
canRewire(const DestructurableMemorySlot & slot,SmallPtrSetImpl<Attribute> & usedIndices,SmallVectorImpl<MemorySlot> & mustBeSafelyUsed,const DataLayout & dataLayout)271 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
272                                 SmallPtrSetImpl<Attribute> &usedIndices,
273                                 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
274                                 const DataLayout &dataLayout) {
275   if (slot.ptr != getMemRef() || getValue() == slot.ptr)
276     return false;
277   Attribute index = getAttributeIndexFromIndexOperands(
278       getContext(), getIndices(), getMemRefType());
279   if (!index || !slot.subelementTypes.contains(index))
280     return false;
281   usedIndices.insert(index);
282   return true;
283 }
284 
rewire(const DestructurableMemorySlot & slot,DenseMap<Attribute,MemorySlot> & subslots,OpBuilder & builder,const DataLayout & dataLayout)285 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
286                                      DenseMap<Attribute, MemorySlot> &subslots,
287                                      OpBuilder &builder,
288                                      const DataLayout &dataLayout) {
289   Attribute index = getAttributeIndexFromIndexOperands(
290       getContext(), getIndices(), getMemRefType());
291   const MemorySlot &memorySlot = subslots.at(index);
292   setMemRef(memorySlot.ptr);
293   getIndicesMutable().clear();
294   return DeletionKind::Keep;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 //  Interfaces for destructurable types
299 //===----------------------------------------------------------------------===//
300 
301 namespace {
302 
303 struct MemRefDestructurableTypeExternalModel
304     : public DestructurableTypeInterface::ExternalModel<
305           MemRefDestructurableTypeExternalModel, MemRefType> {
306   std::optional<DenseMap<Attribute, Type>>
getSubelementIndexMap__anon8a6d594d0311::MemRefDestructurableTypeExternalModel307   getSubelementIndexMap(Type type) const {
308     auto memrefType = llvm::cast<MemRefType>(type);
309     constexpr int64_t maxMemrefSizeForDestructuring = 16;
310     if (!memrefType.hasStaticShape() ||
311         memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312         memrefType.getNumElements() == 1)
313       return {};
314 
315     DenseMap<Attribute, Type> destructured;
316     walkIndicesAsAttr(
317         memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
318           destructured.insert({index, memrefType.getElementType()});
319         });
320 
321     return destructured;
322   }
323 
getTypeAtIndex__anon8a6d594d0311::MemRefDestructurableTypeExternalModel324   Type getTypeAtIndex(Type type, Attribute index) const {
325     auto memrefType = llvm::cast<MemRefType>(type);
326     auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327     if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
328       return {};
329 
330     Type indexType = IndexType::get(memrefType.getContext());
331     for (const auto &[coordAttr, dimSize] :
332          llvm::zip(coordArrAttr, memrefType.getShape())) {
333       auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334       if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335           coord.getInt() >= dimSize)
336         return {};
337     }
338 
339     return memrefType.getElementType();
340   }
341 };
342 
343 } // namespace
344 
345 //===----------------------------------------------------------------------===//
346 //  Register external models
347 //===----------------------------------------------------------------------===//
348 
registerMemorySlotExternalModels(DialectRegistry & registry)349 void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
350   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
351     MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
352   });
353 }
354