xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 
11 #include "mlir/Conversion/LLVMCommon/StructBuilder.h"
12 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
13 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14 
15 #include <optional>
16 
17 using namespace mlir;
18 using namespace sparse_tensor;
19 
20 namespace {
21 
22 //===----------------------------------------------------------------------===//
23 // Helper methods.
24 //===----------------------------------------------------------------------===//
25 
26 static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
27   MLIRContext *ctx = tp.getContext();
28   auto enc = tp.getEncoding();
29   const Level lvlRank = enc.getLvlRank();
30 
31   SmallVector<Type, 4> result;
32   // TODO: how can we get the lowering type for index type in the later pipeline
33   // to be consistent? LLVM::StructureType does not allow index fields.
34   auto sizeType = IntegerType::get(tp.getContext(), 64);
35   auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
36   auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType,
37                                            getNumDataFieldsFromEncoding(enc));
38   result.push_back(lvlSizes);
39   result.push_back(memSizes);
40 
41   if (enc.isSlice()) {
42     // Extra fields are required for the slice information.
43     auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
44     auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
45 
46     result.push_back(dimOffset);
47     result.push_back(dimStride);
48   }
49 
50   return result;
51 }
52 
53 static Type convertSpecifier(StorageSpecifierType tp) {
54   return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55                                           getSpecifierFields(tp));
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Specifier struct builder.
60 //===----------------------------------------------------------------------===//
61 
// Positions of the top-level fields inside the lowered specifier struct, in
// the order getSpecifierFields emits them. The dimension offset/stride fields
// exist only for slice encodings.
constexpr uint64_t kLvlSizePosInSpecifier{0};
constexpr uint64_t kMemSizePosInSpecifier{1};
constexpr uint64_t kDimOffsetPosInSpecifier{2};
constexpr uint64_t kDimStridePosInSpecifier{3};
66 
67 class SpecifierStructBuilder : public StructBuilder {
68 private:
69   Value extractField(OpBuilder &builder, Location loc,
70                      ArrayRef<int64_t> indices) const {
71     return genCast(builder, loc,
72                    builder.create<LLVM::ExtractValueOp>(loc, value, indices),
73                    builder.getIndexType());
74   }
75 
76   void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
77                    Value v) {
78     value = builder.create<LLVM::InsertValueOp>(
79         loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
80         indices);
81   }
82 
83 public:
84   explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
85     assert(value);
86   }
87 
88   // Undef value for dimension sizes, all zero value for memory sizes.
89   static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
90                             Value source);
91 
92   Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
93   void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
94 
95   Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
96   void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
97                     Value size);
98 
99   Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
100   void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
101                     Value size);
102 
103   Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
104   void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
105                   Value size);
106 
107   Value memSizeArray(OpBuilder &builder, Location loc) const;
108   void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
109 };
110 
111 Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
112                                            Type structType, Value source) {
113   Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
114   SpecifierStructBuilder md(metaData);
115   if (!source) {
116     auto memSizeArrayType =
117         cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118                                       .getBody()[kMemSizePosInSpecifier]);
119 
120     Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
121     // Fill memSizes array with zero.
122     for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123       md.setMemSize(builder, loc, i, zero);
124   } else {
125     // We copy non-slice information (memory sizes array) from source
126     SpecifierStructBuilder sourceMd(source);
127     md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
128   }
129   return md;
130 }
131 
132 /// Builds IR extracting the pos-th offset from the descriptor.
133 Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
134                                         Dimension dim) const {
135   return extractField(
136       builder, loc,
137       ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
138 }
139 
140 /// Builds IR inserting the pos-th offset into the descriptor.
141 void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
142                                           Dimension dim, Value size) {
143   insertField(
144       builder, loc,
145       ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
146       size);
147 }
148 
149 /// Builds IR extracting the `lvl`-th level-size from the descriptor.
150 Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
151                                       Level lvl) const {
152   // This static_cast makes the narrowing of `lvl` explicit, as required
153   // by the braces notation for the ctor.
154   return extractField(
155       builder, loc,
156       ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)});
157 }
158 
159 /// Builds IR inserting the `lvl`-th level-size into the descriptor.
160 void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
161                                         Level lvl, Value size) {
162   // This static_cast makes the narrowing of `lvl` explicit, as required
163   // by the braces notation for the ctor.
164   insertField(
165       builder, loc,
166       ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)},
167       size);
168 }
169 
170 /// Builds IR extracting the pos-th stride from the descriptor.
171 Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
172                                         Dimension dim) const {
173   return extractField(
174       builder, loc,
175       ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
176 }
177 
178 /// Builds IR inserting the pos-th stride into the descriptor.
179 void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
180                                           Dimension dim, Value size) {
181   insertField(
182       builder, loc,
183       ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
184       size);
185 }
186 
187 /// Builds IR extracting the pos-th memory size into the descriptor.
188 Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
189                                       FieldIndex fidx) const {
190   return extractField(
191       builder, loc,
192       ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
193 }
194 
195 /// Builds IR inserting the `fidx`-th memory-size into the descriptor.
196 void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
197                                         FieldIndex fidx, Value size) {
198   insertField(
199       builder, loc,
200       ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
201       size);
202 }
203 
204 /// Builds IR extracting the memory size array from the descriptor.
205 Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
206                                            Location loc) const {
207   return builder.create<LLVM::ExtractValueOp>(loc, value,
208                                               kMemSizePosInSpecifier);
209 }
210 
211 /// Builds IR inserting the memory size array into the descriptor.
212 void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
213                                              Value array) {
214   value = builder.create<LLVM::InsertValueOp>(loc, value, array,
215                                               kMemSizePosInSpecifier);
216 }
217 
218 } // namespace
219 
220 //===----------------------------------------------------------------------===//
221 // The sparse storage specifier type converter (defined in Passes.h).
222 //===----------------------------------------------------------------------===//
223 
224 StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
225   addConversion([](Type type) { return type; });
226   addConversion(convertSpecifier);
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Storage specifier conversion rules.
231 //===----------------------------------------------------------------------===//
232 
233 template <typename Base, typename SourceOp>
234 class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
235 public:
236   using OpAdaptor = typename SourceOp::Adaptor;
237   using OpConversionPattern<SourceOp>::OpConversionPattern;
238 
239   LogicalResult
240   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
241                   ConversionPatternRewriter &rewriter) const override {
242     SpecifierStructBuilder spec(adaptor.getSpecifier());
243     switch (op.getSpecifierKind()) {
244     case StorageSpecifierKind::LvlSize: {
245       Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
246       rewriter.replaceOp(op, v);
247       return success();
248     }
249     case StorageSpecifierKind::DimOffset: {
250       Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
251       rewriter.replaceOp(op, v);
252       return success();
253     }
254     case StorageSpecifierKind::DimStride: {
255       Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
256       rewriter.replaceOp(op, v);
257       return success();
258     }
259     case StorageSpecifierKind::CrdMemSize:
260     case StorageSpecifierKind::PosMemSize:
261     case StorageSpecifierKind::ValMemSize: {
262       auto enc = op.getSpecifier().getType().getEncoding();
263       StorageLayout layout(enc);
264       std::optional<unsigned> lvl;
265       if (op.getLevel())
266         lvl = (*op.getLevel());
267       unsigned idx =
268           layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
269       Value v = Base::onMemSize(rewriter, op, spec, idx);
270       rewriter.replaceOp(op, v);
271       return success();
272     }
273     }
274     llvm_unreachable("unrecognized specifer kind");
275   }
276 };
277 
278 struct StorageSpecifierSetOpConverter
279     : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
280                                               SetStorageSpecifierOp> {
281   using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
282 
283   static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
284                          SpecifierStructBuilder &spec, Level lvl) {
285     spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
286     return spec;
287   }
288 
289   static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
290                            SpecifierStructBuilder &spec, Dimension d) {
291     spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
292     return spec;
293   }
294 
295   static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
296                            SpecifierStructBuilder &spec, Dimension d) {
297     spec.setDimStride(builder, op.getLoc(), d, op.getValue());
298     return spec;
299   }
300 
301   static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
302                          SpecifierStructBuilder &spec, FieldIndex fidx) {
303     spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
304     return spec;
305   }
306 };
307 
308 struct StorageSpecifierGetOpConverter
309     : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
310                                               GetStorageSpecifierOp> {
311   using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
312 
313   static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
314                          SpecifierStructBuilder &spec, Level lvl) {
315     return spec.lvlSize(builder, op.getLoc(), lvl);
316   }
317 
318   static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
319                            const SpecifierStructBuilder &spec, Dimension d) {
320     return spec.dimOffset(builder, op.getLoc(), d);
321   }
322 
323   static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
324                            const SpecifierStructBuilder &spec, Dimension d) {
325     return spec.dimStride(builder, op.getLoc(), d);
326   }
327 
328   static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
329                          SpecifierStructBuilder &spec, FieldIndex fidx) {
330     return spec.memSize(builder, op.getLoc(), fidx);
331   }
332 };
333 
334 struct StorageSpecifierInitOpConverter
335     : public OpConversionPattern<StorageSpecifierInitOp> {
336 public:
337   using OpConversionPattern::OpConversionPattern;
338   LogicalResult
339   matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
340                   ConversionPatternRewriter &rewriter) const override {
341     Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
342     rewriter.replaceOp(
343         op, SpecifierStructBuilder::getInitValue(
344                 rewriter, op.getLoc(), llvmType, adaptor.getSource()));
345     return success();
346   }
347 };
348 
349 //===----------------------------------------------------------------------===//
350 // Public method for populating conversion rules.
351 //===----------------------------------------------------------------------===//
352 
353 void mlir::populateStorageSpecifierToLLVMPatterns(
354     const TypeConverter &converter, RewritePatternSet &patterns) {
355   patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
356                StorageSpecifierInitOpConverter>(converter,
357                                                 patterns.getContext());
358 }
359