xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (revision 204234a69c068032a1adac31f00b51f3b9efa778)
1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for generating MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Complex/IR/Complex.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24 #include "mlir/IR/Builders.h"
25 
26 namespace mlir {
27 
28 class Location;
29 class Type;
30 class Value;
31 
32 namespace sparse_tensor {
33 
/// Strongly-typed on/off switch used for the `emitCInterface` argument of
/// `getFunc()`, `createFuncCall()`, and `replaceOpWithFuncCall()`, so that
/// call sites read `EmitCInterface::On` rather than a bare `true`.
enum class EmitCInterface : bool {
  Off = false,
  On = true,
};
37 
38 //===----------------------------------------------------------------------===//
39 // ExecutionEngine/SparseTensorUtils helper functions.
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts an overhead storage bitwidth to its internal type-encoding.
43 OverheadType overheadTypeEncoding(unsigned width);
44 
45 /// Converts an overhead storage type to its internal type-encoding.
46 OverheadType overheadTypeEncoding(Type tp);
47 
48 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
49 Type getOverheadType(Builder &builder, OverheadType ot);
50 
51 /// Returns the OverheadType for position overhead storage.
52 OverheadType posTypeEncoding(SparseTensorEncodingAttr enc);
53 
54 /// Returns the OverheadType for coordinate overhead storage.
55 OverheadType crdTypeEncoding(SparseTensorEncodingAttr enc);
56 
57 /// Converts an OverheadType to its function-name suffix.
58 StringRef overheadTypeFunctionSuffix(OverheadType ot);
59 
60 /// Converts an overhead storage type to its function-name suffix.
61 StringRef overheadTypeFunctionSuffix(Type overheadTp);
62 
63 /// Converts a primary storage type to its internal type-encoding.
64 PrimaryType primaryTypeEncoding(Type elemTp);
65 
66 /// Converts a PrimaryType to its function-name suffix.
67 StringRef primaryTypeFunctionSuffix(PrimaryType pt);
68 
69 /// Converts a primary storage type to its function-name suffix.
70 StringRef primaryTypeFunctionSuffix(Type elemTp);
71 
72 //===----------------------------------------------------------------------===//
73 // Misc code generators and utilities.
74 //===----------------------------------------------------------------------===//
75 
76 /// A helper class to simplify lowering operations with/without function calls.
77 template <class SubClass>
78 class FuncCallOrInlineGenerator {
79 public:
80   FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall)
81       : retTypes(retTypes), params(params), genCall(genCall) {}
82 
83   // The main API invoked by clients, which abstracts away the details of
84   // creating function calls from clients.
85   SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) {
86     if (!genCall)
87       return genImplementation(retTypes, params, builder, loc);
88 
89     // Looks up the function.
90     std::string funcName = getMangledFuncName();
91     ModuleOp module = getParentOpOf<ModuleOp>(builder);
92     MLIRContext *context = module.getContext();
93     auto result = SymbolRefAttr::get(context, funcName);
94     auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
95 
96     if (!func) {
97       // Create the function if not already exist.
98       OpBuilder::InsertionGuard insertionGuard(builder);
99       builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder));
100       func = builder.create<func::FuncOp>(
101           loc, funcName,
102           FunctionType::get(context, params.getTypes(), retTypes));
103       func.setPrivate();
104       // Set the insertion point to the body of the function.
105       Block *entryBB = func.addEntryBlock();
106       builder.setInsertionPointToStart(entryBB);
107       ValueRange args = entryBB->getArguments();
108       // Delegates to user to generate the actually implementation.
109       SmallVector<Value> result =
110           genImplementation(retTypes, args, builder, loc);
111       builder.create<func::ReturnOp>(loc, result);
112     }
113     // Returns the CallOp result.
114     func::CallOp call = builder.create<func::CallOp>(loc, func, params);
115     return call.getResults();
116   }
117 
118 private:
119   template <class OpTp>
120   OpTp getParentOpOf(OpBuilder &builder) {
121     return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>();
122   }
123 
124   // CRTP: get the mangled function name (only called when genCall=true).
125   std::string getMangledFuncName() {
126     return static_cast<SubClass *>(this)->getMangledFuncName();
127   }
128 
129   // CRTP: Client implementation.
130   SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params,
131                                        OpBuilder &builder, Location loc) {
132     return static_cast<SubClass *>(this)->genImplementation(retTypes, params,
133                                                             builder, loc);
134   }
135 
136 private:
137   TypeRange retTypes; // The types of all returned results
138   ValueRange params;  // The values of all input parameters
139   bool genCall;       // Should the implemetantion be wrapped in a function
140 };
141 
142 /// Add type casting between arith and index types when needed.
143 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
144 
145 /// Add conversion from scalar to given type (possibly a 0-rank tensor).
146 Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
147                         Type dstTp);
148 
149 /// Generates a pointer/index load from the sparse storage scheme. Narrower
150 /// data types need to be zero extended before casting the value into the
151 /// index type used for looping and indexing.
152 Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s);
153 
154 /// Generates a 1-valued attribute of the given type.  This supports
155 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
156 /// for unsupported types we raise `llvm_unreachable` rather than
157 /// returning a null attribute.
158 TypedAttr getOneAttr(Builder &builder, Type tp);
159 
160 /// Generates the comparison `v != 0` where `v` is of numeric type.
161 /// For floating types, we use the "unordered" comparator (i.e., returns
162 /// true if `v` is NaN).
163 Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
164 
165 /// Computes the shape of destination tensor of a reshape operator. This is only
166 /// used when operands have dynamic shape. The shape of the destination is
167 /// stored into dstShape.
168 void genReshapeDstShape(OpBuilder &builder, Location loc,
169                         SmallVectorImpl<Value> &dstShape,
170                         ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
171                         ArrayRef<ReassociationIndices> reassociation);
172 
173 /// Reshape coordinates during a reshaping operation.
174 void reshapeCvs(OpBuilder &builder, Location loc,
175                 ArrayRef<ReassociationIndices> reassociation,
176                 ValueRange srcSizes, ValueRange srcCvs, // NOLINT
177                 ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs);
178 
179 /// Returns a function reference (first hit also inserts into module). Sets
180 /// the "_emit_c_interface" on the function declaration when requested,
181 /// so that LLVM lowering generates a wrapper function that takes care
182 /// of ABI complications with passing in and returning MemRefs to C functions.
183 FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType,
184                           ValueRange operands, EmitCInterface emitCInterface);
185 
186 /// Creates a `CallOp` to the function reference returned by `getFunc()` in
187 /// the builder's module.
188 func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name,
189                             TypeRange resultType, ValueRange operands,
190                             EmitCInterface emitCInterface);
191 
192 /// Returns the equivalent of `void*` for opaque arguments to the
193 /// execution engine.
194 Type getOpaquePointerType(MLIRContext *ctx);
195 Type getOpaquePointerType(Builder &builder);
196 
197 /// Generates an uninitialized temporary buffer of the given size and
198 /// type, but returns it as type `memref<? x $tp>` (rather than as type
199 /// `memref<$sz x $tp>`).
200 Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp);
201 
202 /// Generates an uninitialized temporary buffer of the given size and
203 /// type, and returns it as type `memref<? x $tp>` (staticShape=false) or
204 /// `memref<$sz x $tp>` (staticShape=true).
205 Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp,
206                 bool staticShape = false);
207 
208 /// Generates an uninitialized temporary buffer with room for one value
209 /// of the given type, and returns the `memref<$tp>`.
210 Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp);
211 
212 /// Generates a temporary buffer, initializes it with the given contents,
213 /// and returns it as type `memref<? x $tp>` (rather than specifying the
214 /// size of the buffer).
215 Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values);
216 
217 /// Generates code to allocate a buffer of the given type, and zero
218 /// initialize it.  If the buffer type has any dynamic sizes, then the
219 /// `sizes` parameter should be as filled by sizesFromPtr(); that way
220 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr().
221 Value allocDenseTensor(OpBuilder &builder, Location loc,
222                        RankedTensorType tensorTp, ValueRange sizes);
223 
224 /// Generates code to deallocate a dense buffer.
225 void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);
226 
227 /// Populates given sizes array from dense tensor or sparse tensor constant.
228 void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
229                   Location loc, Value src);
230 
231 /// Scans to top of generated loop.
232 Operation *getTop(Operation *op);
233 
234 /// Iterates over a sparse constant and generates a constantOp for each
235 /// value and its coordinates.  E.g.,
236 /// sparse<[ [0], [28], [31] ],
237 ///          [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] >
238 /// =>
239 /// %c1 = arith.constant 0
240 /// %v1 = complex.constant (-5.13, 2.0)
241 /// callback({%c1}, %v1)
242 ///
243 /// %c2 = arith.constant 28
244 /// %v2 = complex.constant (3.0, 4.0)
245 /// callback({%c2}, %v2)
246 ///
247 /// %c3 = arith.constant 31
248 /// %v3 = complex.constant (5.0, 6.0)
249 /// callback({%c3}, %v3)
250 void foreachInSparseConstant(
251     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
252     function_ref<void(ArrayRef<Value>, Value)> callback);
253 
254 /// Loads `size`-many values from the memref, which must have rank-1 and
255 /// size greater-or-equal to `size`.  If the optional `(offsetIdx,offsetVal)`
256 /// arguments are provided, then the `offsetVal` will be added to the
257 /// `offsetIdx`-th value after loading.
258 SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
259                            Value mem, size_t offsetIdx = 0,
260                            Value offsetVal = Value());
261 
262 /// Stores all the values of `vs` into the memref `mem`, which must have
263 /// rank-1 and size greater-or-equal to `vs.size()`.  If the optional
264 /// `(offsetIdx,offsetVal)` arguments are provided, then the `offsetVal`
265 /// will be added to the `offsetIdx`-th value before storing.
266 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
267               size_t offsetIdx = 0, Value offsetVal = Value());
268 
269 /// Generates code to cast a tensor to a memref.
270 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
271                                        Value tensor);
272 
273 /// Generates code to retrieve the slice offset for the sparse tensor slice,
274 /// returning a constant if the offset is statically known.
275 Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
276                                 Dimension dim);
277 
278 /// Generates code to retrieve the slice stride for the sparse tensor slice,
279 /// returning a constant if the stride is statically known.
280 Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
281                                 Dimension dim);
282 
283 /// Generates code that opens a reader and sets the dimension sizes.
284 Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt,
285                 Value tensor,
286                 /*out*/ SmallVectorImpl<Value> &dimSizesValues,
287                 /*out*/ Value &dimSizesBuffer);
288 
289 /// Generates code to set up the buffer parameters for a map.
290 Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt,
291                     ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
292                     /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
293                     /*out*/ Value &dim2lvlBuffer,
294                     /*out*/ Value &lvl2dimBuffer);
295 
296 //===----------------------------------------------------------------------===//
297 // Inlined constant generators.
298 //
299 // All these functions are just wrappers to improve code legibility;
300 // therefore, we mark them as `inline` to avoid introducing any additional
301 // overhead due to the legibility. Ideally these should move upstream.
302 //
303 //===----------------------------------------------------------------------===//
304 
305 /// Generates a 0-valued constant of the given type.  In addition to
306 /// the scalar types (`ComplexType`, `FloatType`, `IndexType`,
307 /// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
308 /// (for which it generates a constant `DenseElementsAttr` of zeros).
309 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
310   if (auto ctp = dyn_cast<ComplexType>(tp)) {
311     auto zeroe = builder.getZeroAttr(ctp.getElementType());
312     auto zeroa = builder.getArrayAttr({zeroe, zeroe});
313     return builder.create<complex::ConstantOp>(loc, tp, zeroa);
314   }
315   return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
316 }
317 
318 /// Generates a 1-valued constant of the given type.  This supports all
319 /// the same types as `constantZero`.
320 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
321   if (auto ctp = dyn_cast<ComplexType>(tp)) {
322     auto zeroe = builder.getZeroAttr(ctp.getElementType());
323     auto onee = getOneAttr(builder, ctp.getElementType());
324     auto zeroa = builder.getArrayAttr({onee, zeroe});
325     return builder.create<complex::ConstantOp>(loc, tp, zeroa);
326   }
327   return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
328 }
329 
330 /// Generates a constant of `index` type.
331 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
332   return builder.create<arith::ConstantIndexOp>(loc, i);
333 }
334 
335 /// Generates a constant of `i64` type.
336 inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) {
337   return builder.create<arith::ConstantIntOp>(loc, i, 64);
338 }
339 
340 /// Generates a constant of `i32` type.
341 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
342   return builder.create<arith::ConstantIntOp>(loc, i, 32);
343 }
344 
345 /// Generates a constant of `i16` type.
346 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
347   return builder.create<arith::ConstantIntOp>(loc, i, 16);
348 }
349 
350 /// Generates a constant of `i8` type.
351 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
352   return builder.create<arith::ConstantIntOp>(loc, i, 8);
353 }
354 
355 /// Generates a constant of `i1` type.
356 inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
357   return builder.create<arith::ConstantIntOp>(loc, b, 1);
358 }
359 
360 /// Generates a constant of the given `Action`.
361 inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
362   return constantI32(builder, loc, static_cast<uint32_t>(action));
363 }
364 
365 /// Generates a constant of the internal type-encoding for overhead storage.
366 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
367                                           unsigned width) {
368   return constantI32(builder, loc,
369                      static_cast<uint32_t>(overheadTypeEncoding(width)));
370 }
371 
372 /// Generates a constant of the internal type-encoding for position
373 /// overhead storage.
374 inline Value constantPosTypeEncoding(OpBuilder &builder, Location loc,
375                                      SparseTensorEncodingAttr enc) {
376   return constantOverheadTypeEncoding(builder, loc, enc.getPosWidth());
377 }
378 
379 /// Generates a constant of the internal type-encoding for coordinate
380 /// overhead storage.
381 inline Value constantCrdTypeEncoding(OpBuilder &builder, Location loc,
382                                      SparseTensorEncodingAttr enc) {
383   return constantOverheadTypeEncoding(builder, loc, enc.getCrdWidth());
384 }
385 
386 /// Generates a constant of the internal type-encoding for primary storage.
387 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
388                                          Type elemTp) {
389   return constantI32(builder, loc,
390                      static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
391 }
392 
393 /// Generates a constant of the internal dimension level type encoding.
394 inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
395                                        LevelType lt) {
396   return constantI64(builder, loc, static_cast<uint64_t>(lt));
397 }
398 
399 // Generates a constant from a validated value carrying attribute.
400 inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
401   if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
402     Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
403     return builder.create<complex::ConstantOp>(
404         loc, complexAttr.getType(),
405         builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
406                               FloatAttr::get(tp, complexAttr.getImag())}));
407   }
408   return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
409 }
410 
411 // TODO: is this at the right place?
412 inline bool isZeroRankedTensorOrScalar(Type type) {
413   auto rtp = dyn_cast<RankedTensorType>(type);
414   return !rtp || rtp.getRank() == 0;
415 }
416 
417 } // namespace sparse_tensor
418 } // namespace mlir
419 
420 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_
421