1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for generating MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Complex/IR/Complex.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 24 #include "mlir/IR/Builders.h" 25 26 namespace mlir { 27 28 class Location; 29 class Type; 30 class Value; 31 32 namespace sparse_tensor { 33 34 /// Shorthand aliases for the `emitCInterface` argument to `getFunc()`, 35 /// `createFuncCall()`, and `replaceOpWithFuncCall()`. 36 enum class EmitCInterface : bool { Off = false, On = true }; 37 38 //===----------------------------------------------------------------------===// 39 // ExecutionEngine/SparseTensorUtils helper functions. 40 //===----------------------------------------------------------------------===// 41 42 /// Converts an overhead storage bitwidth to its internal type-encoding. 43 OverheadType overheadTypeEncoding(unsigned width); 44 45 /// Converts an overhead storage type to its internal type-encoding. 46 OverheadType overheadTypeEncoding(Type tp); 47 48 /// Converts the internal type-encoding for overhead storage to an mlir::Type. 49 Type getOverheadType(Builder &builder, OverheadType ot); 50 51 /// Returns the OverheadType for position overhead storage. 52 OverheadType posTypeEncoding(SparseTensorEncodingAttr enc); 53 54 /// Returns the OverheadType for coordinate overhead storage. 55 OverheadType crdTypeEncoding(SparseTensorEncodingAttr enc); 56 57 /// Convert OverheadType to its function-name suffix. 58 StringRef overheadTypeFunctionSuffix(OverheadType ot); 59 60 /// Converts an overhead storage type to its function-name suffix. 61 StringRef overheadTypeFunctionSuffix(Type overheadTp); 62 63 /// Converts a primary storage type to its internal type-encoding. 64 PrimaryType primaryTypeEncoding(Type elemTp); 65 66 /// Convert PrimaryType to its function-name suffix. 67 StringRef primaryTypeFunctionSuffix(PrimaryType pt); 68 69 /// Converts a primary storage type to its function-name suffix. 70 StringRef primaryTypeFunctionSuffix(Type elemTp); 71 72 //===----------------------------------------------------------------------===// 73 // Misc code generators and utilities. 74 //===----------------------------------------------------------------------===// 75 76 /// A helper class to simplify lowering operations with/without function calls. 77 template <class SubClass> 78 class FuncCallOrInlineGenerator { 79 public: 80 FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall) 81 : retTypes(retTypes), params(params), genCall(genCall) {} 82 83 // The main API invoked by clients, which abstracts away the details of 84 // creating function calls from clients. 85 SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) { 86 if (!genCall) 87 return genImplementation(retTypes, params, builder, loc); 88 89 // Looks up the function. 90 std::string funcName = getMangledFuncName(); 91 ModuleOp module = getParentOpOf<ModuleOp>(builder); 92 MLIRContext *context = module.getContext(); 93 auto result = SymbolRefAttr::get(context, funcName); 94 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); 95 96 if (!func) { 97 // Create the function if not already exist. 98 OpBuilder::InsertionGuard insertionGuard(builder); 99 builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder)); 100 func = builder.create<func::FuncOp>( 101 loc, funcName, 102 FunctionType::get(context, params.getTypes(), retTypes)); 103 func.setPrivate(); 104 // Set the insertion point to the body of the function. 105 Block *entryBB = func.addEntryBlock(); 106 builder.setInsertionPointToStart(entryBB); 107 ValueRange args = entryBB->getArguments(); 108 // Delegates to user to generate the actually implementation. 109 SmallVector<Value> result = 110 genImplementation(retTypes, args, builder, loc); 111 builder.create<func::ReturnOp>(loc, result); 112 } 113 // Returns the CallOp result. 114 func::CallOp call = builder.create<func::CallOp>(loc, func, params); 115 return call.getResults(); 116 } 117 118 private: 119 template <class OpTp> 120 OpTp getParentOpOf(OpBuilder &builder) { 121 return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>(); 122 } 123 124 // CRTP: get the mangled function name (only called when genCall=true). 125 std::string getMangledFuncName() { 126 return static_cast<SubClass *>(this)->getMangledFuncName(); 127 } 128 129 // CRTP: Client implementation. 130 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params, 131 OpBuilder &builder, Location loc) { 132 return static_cast<SubClass *>(this)->genImplementation(retTypes, params, 133 builder, loc); 134 } 135 136 private: 137 TypeRange retTypes; // The types of all returned results 138 ValueRange params; // The values of all input parameters 139 bool genCall; // Should the implemetantion be wrapped in a function 140 }; 141 142 /// Add type casting between arith and index types when needed. 143 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); 144 145 /// Add conversion from scalar to given type (possibly a 0-rank tensor). 146 Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, 147 Type dstTp); 148 149 /// Generates a pointer/index load from the sparse storage scheme. Narrower 150 /// data types need to be zero extended before casting the value into the 151 /// index type used for looping and indexing. 152 Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s); 153 154 /// Generates a 1-valued attribute of the given type. This supports 155 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, 156 /// for unsupported types we raise `llvm_unreachable` rather than 157 /// returning a null attribute. 158 TypedAttr getOneAttr(Builder &builder, Type tp); 159 160 /// Generates the comparison `v != 0` where `v` is of numeric type. 161 /// For floating types, we use the "unordered" comparator (i.e., returns 162 /// true if `v` is NaN). 163 Value genIsNonzero(OpBuilder &builder, Location loc, Value v); 164 165 /// Computes the shape of destination tensor of a reshape operator. This is only 166 /// used when operands have dynamic shape. The shape of the destination is 167 /// stored into dstShape. 168 void genReshapeDstShape(OpBuilder &builder, Location loc, 169 SmallVectorImpl<Value> &dstShape, 170 ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape, 171 ArrayRef<ReassociationIndices> reassociation); 172 173 /// Reshape coordinates during a reshaping operation. 174 void reshapeCvs(OpBuilder &builder, Location loc, 175 ArrayRef<ReassociationIndices> reassociation, 176 ValueRange srcSizes, ValueRange srcCvs, // NOLINT 177 ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs); 178 179 /// Returns a function reference (first hit also inserts into module). Sets 180 /// the "_emit_c_interface" on the function declaration when requested, 181 /// so that LLVM lowering generates a wrapper function that takes care 182 /// of ABI complications with passing in and returning MemRefs to C functions. 183 FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, 184 ValueRange operands, EmitCInterface emitCInterface); 185 186 /// Creates a `CallOp` to the function reference returned by `getFunc()` in 187 /// the builder's module. 188 func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, 189 TypeRange resultType, ValueRange operands, 190 EmitCInterface emitCInterface); 191 192 /// Returns the equivalent of `void*` for opaque arguments to the 193 /// execution engine. 194 Type getOpaquePointerType(MLIRContext *ctx); 195 Type getOpaquePointerType(Builder &builder); 196 197 /// Generates an uninitialized temporary buffer of the given size and 198 /// type, but returns it as type `memref<? x $tp>` (rather than as type 199 /// `memref<$sz x $tp>`). 200 Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp); 201 202 /// Generates an uninitialized temporary buffer of the given size and 203 /// type, and returns it as type `memref<? x $tp>` (staticShape=false) or 204 /// `memref<$sz x $tp>` (staticShape=true). 205 Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp, 206 bool staticShape = false); 207 208 /// Generates an uninitialized temporary buffer with room for one value 209 /// of the given type, and returns the `memref<$tp>`. 210 Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp); 211 212 /// Generates a temporary buffer, initializes it with the given contents, 213 /// and returns it as type `memref<? x $tp>` (rather than specifying the 214 /// size of the buffer). 215 Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values); 216 217 /// Generates code to allocate a buffer of the given type, and zero 218 /// initialize it. If the buffer type has any dynamic sizes, then the 219 /// `sizes` parameter should be as filled by sizesFromPtr(); that way 220 /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). 221 Value allocDenseTensor(OpBuilder &builder, Location loc, 222 RankedTensorType tensorTp, ValueRange sizes); 223 224 /// Generates code to deallocate a dense buffer. 225 void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer); 226 227 /// Populates given sizes array from dense tensor or sparse tensor constant. 228 void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes, 229 Location loc, Value src); 230 231 /// Scans to top of generated loop. 232 Operation *getTop(Operation *op); 233 234 /// Iterate over a sparse constant, generates constantOp for value 235 /// and coordinates. E.g., 236 /// sparse<[ [0], [28], [31] ], 237 /// [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > 238 /// => 239 /// %c1 = arith.constant 0 240 /// %v1 = complex.constant (5.13, 2.0) 241 /// callback({%c1}, %v1) 242 /// 243 /// %c2 = arith.constant 28 244 /// %v2 = complex.constant (3.0, 4.0) 245 /// callback({%c2}, %v2) 246 /// 247 /// %c3 = arith.constant 31 248 /// %v3 = complex.constant (5.0, 6.0) 249 /// callback({%c3}, %v3) 250 void foreachInSparseConstant( 251 OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, 252 function_ref<void(ArrayRef<Value>, Value)> callback); 253 254 /// Loads `size`-many values from the memref, which must have rank-1 and 255 /// size greater-or-equal to `size`. If the optional `(offsetIdx,offsetVal)` 256 /// arguments are provided, then the `offsetVal` will be added to the 257 /// `offsetIdx`-th value after loading. 258 SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size, 259 Value mem, size_t offsetIdx = 0, 260 Value offsetVal = Value()); 261 262 /// Stores all the values of `vs` into the memref `mem`, which must have 263 /// rank-1 and size greater-or-equal to `vs.size()`. If the optional 264 /// `(offsetIdx,offsetVal)` arguments are provided, then the `offsetVal` 265 /// will be added to the `offsetIdx`-th value before storing. 266 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, 267 size_t offsetIdx = 0, Value offsetVal = Value()); 268 269 // Generates code to cast a tensor to a memref. 270 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc, 271 Value tensor); 272 273 /// Generates code to retrieve the slice offset for the sparse tensor slice, 274 /// return a constant if the offset is statically known. 275 Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor, 276 Dimension dim); 277 278 /// Generates code to retrieve the slice slice for the sparse tensor slice, 279 /// return a constant if the offset is statically known. 280 Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor, 281 Dimension dim); 282 283 /// Generates code that opens a reader and sets the dimension sizes. 284 Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, 285 Value tensor, 286 /*out*/ SmallVectorImpl<Value> &dimSizesValues, 287 /*out*/ Value &dimSizesBuffer); 288 289 /// Generates code to set up the buffer parameters for a map. 290 Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, 291 ArrayRef<Value> dimSizesValues, Value dimSizesBuffer, 292 /*out*/ SmallVectorImpl<Value> &lvlSizesValues, 293 /*out*/ Value &dim2lvlBuffer, 294 /*out*/ Value &lvl2dimBuffer); 295 296 //===----------------------------------------------------------------------===// 297 // Inlined constant generators. 298 // 299 // All these functions are just wrappers to improve code legibility; 300 // therefore, we mark them as `inline` to avoid introducing any additional 301 // overhead due to the legibility. Ideally these should move upstream. 302 // 303 //===----------------------------------------------------------------------===// 304 305 /// Generates a 0-valued constant of the given type. In addition to 306 /// the scalar types (`ComplexType`, `FloatType`, `IndexType`, 307 /// `IntegerType`), this also works for `RankedTensorType` and `VectorType` 308 /// (for which it generates a constant `DenseElementsAttr` of zeros). 309 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { 310 if (auto ctp = dyn_cast<ComplexType>(tp)) { 311 auto zeroe = builder.getZeroAttr(ctp.getElementType()); 312 auto zeroa = builder.getArrayAttr({zeroe, zeroe}); 313 return builder.create<complex::ConstantOp>(loc, tp, zeroa); 314 } 315 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp)); 316 } 317 318 /// Generates a 1-valued constant of the given type. This supports all 319 /// the same types as `constantZero`. 320 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { 321 if (auto ctp = dyn_cast<ComplexType>(tp)) { 322 auto zeroe = builder.getZeroAttr(ctp.getElementType()); 323 auto onee = getOneAttr(builder, ctp.getElementType()); 324 auto zeroa = builder.getArrayAttr({onee, zeroe}); 325 return builder.create<complex::ConstantOp>(loc, tp, zeroa); 326 } 327 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp)); 328 } 329 330 /// Generates a constant of `index` type. 331 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { 332 return builder.create<arith::ConstantIndexOp>(loc, i); 333 } 334 335 /// Generates a constant of `i64` type. 336 inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) { 337 return builder.create<arith::ConstantIntOp>(loc, i, 64); 338 } 339 340 /// Generates a constant of `i32` type. 341 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { 342 return builder.create<arith::ConstantIntOp>(loc, i, 32); 343 } 344 345 /// Generates a constant of `i16` type. 346 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { 347 return builder.create<arith::ConstantIntOp>(loc, i, 16); 348 } 349 350 /// Generates a constant of `i8` type. 351 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { 352 return builder.create<arith::ConstantIntOp>(loc, i, 8); 353 } 354 355 /// Generates a constant of `i1` type. 356 inline Value constantI1(OpBuilder &builder, Location loc, bool b) { 357 return builder.create<arith::ConstantIntOp>(loc, b, 1); 358 } 359 360 /// Generates a constant of the given `Action`. 361 inline Value constantAction(OpBuilder &builder, Location loc, Action action) { 362 return constantI32(builder, loc, static_cast<uint32_t>(action)); 363 } 364 365 /// Generates a constant of the internal type-encoding for overhead storage. 366 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, 367 unsigned width) { 368 return constantI32(builder, loc, 369 static_cast<uint32_t>(overheadTypeEncoding(width))); 370 } 371 372 /// Generates a constant of the internal type-encoding for position 373 /// overhead storage. 374 inline Value constantPosTypeEncoding(OpBuilder &builder, Location loc, 375 SparseTensorEncodingAttr enc) { 376 return constantOverheadTypeEncoding(builder, loc, enc.getPosWidth()); 377 } 378 379 /// Generates a constant of the internal type-encoding for coordinate 380 /// overhead storage. 381 inline Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, 382 SparseTensorEncodingAttr enc) { 383 return constantOverheadTypeEncoding(builder, loc, enc.getCrdWidth()); 384 } 385 386 /// Generates a constant of the internal type-encoding for primary storage. 387 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, 388 Type elemTp) { 389 return constantI32(builder, loc, 390 static_cast<uint32_t>(primaryTypeEncoding(elemTp))); 391 } 392 393 /// Generates a constant of the internal dimension level type encoding. 394 inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, 395 LevelType lt) { 396 return constantI64(builder, loc, static_cast<uint64_t>(lt)); 397 } 398 399 // Generates a constant from a validated value carrying attribute. 400 inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) { 401 if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) { 402 Type tp = cast<ComplexType>(complexAttr.getType()).getElementType(); 403 return builder.create<complex::ConstantOp>( 404 loc, complexAttr.getType(), 405 builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()), 406 FloatAttr::get(tp, complexAttr.getImag())})); 407 } 408 return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr)); 409 } 410 411 // TODO: is this at the right place? 412 inline bool isZeroRankedTensorOrScalar(Type type) { 413 auto rtp = dyn_cast<RankedTensorType>(type); 414 return !rtp || rtp.getRank() == 0; 415 } 416 417 } // namespace sparse_tensor 418 } // namespace mlir 419 420 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ 421