//===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This header file defines utilities for generating MLIR. // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Builders.h" namespace mlir { class Location; class Type; class Value; namespace sparse_tensor { /// Shorthand aliases for the `emitCInterface` argument to `getFunc()`, /// `createFuncCall()`, and `replaceOpWithFuncCall()`. enum class EmitCInterface : bool { Off = false, On = true }; //===----------------------------------------------------------------------===// // ExecutionEngine/SparseTensorUtils helper functions. //===----------------------------------------------------------------------===// /// Converts an overhead storage bitwidth to its internal type-encoding. OverheadType overheadTypeEncoding(unsigned width); /// Converts an overhead storage type to its internal type-encoding. OverheadType overheadTypeEncoding(Type tp); /// Converts the internal type-encoding for overhead storage to an mlir::Type. Type getOverheadType(Builder &builder, OverheadType ot); /// Returns the OverheadType for position overhead storage. OverheadType posTypeEncoding(SparseTensorEncodingAttr enc); /// Returns the OverheadType for coordinate overhead storage. OverheadType crdTypeEncoding(SparseTensorEncodingAttr enc); /// Convert OverheadType to its function-name suffix. StringRef overheadTypeFunctionSuffix(OverheadType ot); /// Converts an overhead storage type to its function-name suffix. StringRef overheadTypeFunctionSuffix(Type overheadTp); /// Converts a primary storage type to its internal type-encoding. PrimaryType primaryTypeEncoding(Type elemTp); /// Convert PrimaryType to its function-name suffix. StringRef primaryTypeFunctionSuffix(PrimaryType pt); /// Converts a primary storage type to its function-name suffix. StringRef primaryTypeFunctionSuffix(Type elemTp); //===----------------------------------------------------------------------===// // Misc code generators and utilities. //===----------------------------------------------------------------------===// /// A helper class to simplify lowering operations with/without function calls. template class FuncCallOrInlineGenerator { public: FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall) : retTypes(retTypes), params(params), genCall(genCall) {} // The main API invoked by clients, which abstracts away the details of // creating function calls from clients. SmallVector genCallOrInline(OpBuilder &builder, Location loc) { if (!genCall) return genImplementation(retTypes, params, builder, loc); // Looks up the function. std::string funcName = getMangledFuncName(); ModuleOp module = getParentOpOf(builder); MLIRContext *context = module.getContext(); auto result = SymbolRefAttr::get(context, funcName); auto func = module.lookupSymbol(result.getAttr()); if (!func) { // Create the function if not already exist. OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(getParentOpOf(builder)); func = builder.create( loc, funcName, FunctionType::get(context, params.getTypes(), retTypes)); func.setPrivate(); // Set the insertion point to the body of the function. Block *entryBB = func.addEntryBlock(); builder.setInsertionPointToStart(entryBB); ValueRange args = entryBB->getArguments(); // Delegates to user to generate the actually implementation. SmallVector result = genImplementation(retTypes, args, builder, loc); builder.create(loc, result); } // Returns the CallOp result. func::CallOp call = builder.create(loc, func, params); return call.getResults(); } private: template OpTp getParentOpOf(OpBuilder &builder) { return builder.getInsertionBlock()->getParent()->getParentOfType(); } // CRTP: get the mangled function name (only called when genCall=true). std::string getMangledFuncName() { return static_cast(this)->getMangledFuncName(); } // CRTP: Client implementation. SmallVector genImplementation(TypeRange retTypes, ValueRange params, OpBuilder &builder, Location loc) { return static_cast(this)->genImplementation(retTypes, params, builder, loc); } private: TypeRange retTypes; // The types of all returned results ValueRange params; // The values of all input parameters bool genCall; // Should the implemetantion be wrapped in a function }; /// Add type casting between arith and index types when needed. Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); /// Add conversion from scalar to given type (possibly a 0-rank tensor). Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp); /// Generates a pointer/index load from the sparse storage scheme. Narrower /// data types need to be zero extended before casting the value into the /// index type used for looping and indexing. Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s); /// Generates a 1-valued attribute of the given type. This supports /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, /// for unsupported types we raise `llvm_unreachable` rather than /// returning a null attribute. TypedAttr getOneAttr(Builder &builder, Type tp); /// Generates the comparison `v != 0` where `v` is of numeric type. /// For floating types, we use the "unordered" comparator (i.e., returns /// true if `v` is NaN). Value genIsNonzero(OpBuilder &builder, Location loc, Value v); /// Computes the shape of destination tensor of a reshape operator. This is only /// used when operands have dynamic shape. The shape of the destination is /// stored into dstShape. void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl &dstShape, ArrayRef srcShape, ArrayRef staticDstShape, ArrayRef reassociation); /// Reshape coordinates during a reshaping operation. void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef reassociation, ValueRange srcSizes, ValueRange srcCvs, // NOLINT ValueRange dstSizes, SmallVectorImpl &dstCvs); /// Returns a function reference (first hit also inserts into module). Sets /// the "_emit_c_interface" on the function declaration when requested, /// so that LLVM lowering generates a wrapper function that takes care /// of ABI complications with passing in and returning MemRefs to C functions. FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface); /// Creates a `CallOp` to the function reference returned by `getFunc()` in /// the builder's module. func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface); /// Returns the equivalent of `void*` for opaque arguments to the /// execution engine. Type getOpaquePointerType(MLIRContext *ctx); Type getOpaquePointerType(Builder &builder); /// Generates an uninitialized temporary buffer of the given size and /// type, but returns it as type `memref` (rather than as type /// `memref<$sz x $tp>`). Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp); /// Generates an uninitialized temporary buffer of the given size and /// type, and returns it as type `memref` (staticShape=false) or /// `memref<$sz x $tp>` (staticShape=true). Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp, bool staticShape = false); /// Generates an uninitialized temporary buffer with room for one value /// of the given type, and returns the `memref<$tp>`. Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp); /// Generates a temporary buffer, initializes it with the given contents, /// and returns it as type `memref` (rather than specifying the /// size of the buffer). Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values); /// Generates code to allocate a buffer of the given type, and zero /// initialize it. If the buffer type has any dynamic sizes, then the /// `sizes` parameter should be as filled by sizesFromPtr(); that way /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). Value allocDenseTensor(OpBuilder &builder, Location loc, RankedTensorType tensorTp, ValueRange sizes); /// Generates code to deallocate a dense buffer. void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer); /// Populates given sizes array from dense tensor or sparse tensor constant. void sizesFromSrc(OpBuilder &builder, SmallVectorImpl &sizes, Location loc, Value src); /// Scans to top of generated loop. Operation *getTop(Operation *op); /// Iterate over a sparse constant, generates constantOp for value /// and coordinates. E.g., /// sparse<[ [0], [28], [31] ], /// [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > /// => /// %c1 = arith.constant 0 /// %v1 = complex.constant (5.13, 2.0) /// callback({%c1}, %v1) /// /// %c2 = arith.constant 28 /// %v2 = complex.constant (3.0, 4.0) /// callback({%c2}, %v2) /// /// %c3 = arith.constant 31 /// %v3 = complex.constant (5.0, 6.0) /// callback({%c3}, %v3) void foreachInSparseConstant( OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref, Value)> callback); /// Loads `size`-many values from the memref, which must have rank-1 and /// size greater-or-equal to `size`. If the optional `(offsetIdx,offsetVal)` /// arguments are provided, then the `offsetVal` will be added to the /// `offsetIdx`-th value after loading. SmallVector loadAll(OpBuilder &builder, Location loc, size_t size, Value mem, size_t offsetIdx = 0, Value offsetVal = Value()); /// Stores all the values of `vs` into the memref `mem`, which must have /// rank-1 and size greater-or-equal to `vs.size()`. If the optional /// `(offsetIdx,offsetVal)` arguments are provided, then the `offsetVal` /// will be added to the `offsetIdx`-th value before storing. void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx = 0, Value offsetVal = Value()); // Generates code to cast a tensor to a memref. TypedValue genToMemref(OpBuilder &builder, Location loc, Value tensor); /// Generates code to retrieve the slice offset for the sparse tensor slice, /// return a constant if the offset is statically known. Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor, Dimension dim); /// Generates code to retrieve the slice slice for the sparse tensor slice, /// return a constant if the offset is statically known. Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor, Dimension dim); /// Generates code that opens a reader and sets the dimension sizes. Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, /*out*/ SmallVectorImpl &dimSizesValues, /*out*/ Value &dimSizesBuffer); /// Generates code to set up the buffer parameters for a map. Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef dimSizesValues, Value dimSizesBuffer, /*out*/ SmallVectorImpl &lvlSizesValues, /*out*/ Value &dim2lvlBuffer, /*out*/ Value &lvl2dimBuffer); //===----------------------------------------------------------------------===// // Inlined constant generators. // // All these functions are just wrappers to improve code legibility; // therefore, we mark them as `inline` to avoid introducing any additional // overhead due to the legibility. Ideally these should move upstream. // //===----------------------------------------------------------------------===// /// Generates a 0-valued constant of the given type. In addition to /// the scalar types (`ComplexType`, `FloatType`, `IndexType`, /// `IntegerType`), this also works for `RankedTensorType` and `VectorType` /// (for which it generates a constant `DenseElementsAttr` of zeros). inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { if (auto ctp = dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto zeroa = builder.getArrayAttr({zeroe, zeroe}); return builder.create(loc, tp, zeroa); } return builder.create(loc, tp, builder.getZeroAttr(tp)); } /// Generates a 1-valued constant of the given type. This supports all /// the same types as `constantZero`. inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { if (auto ctp = dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto onee = getOneAttr(builder, ctp.getElementType()); auto zeroa = builder.getArrayAttr({onee, zeroe}); return builder.create(loc, tp, zeroa); } return builder.create(loc, tp, getOneAttr(builder, tp)); } /// Generates a constant of `index` type. inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { return builder.create(loc, i); } /// Generates a constant of `i64` type. inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) { return builder.create(loc, i, 64); } /// Generates a constant of `i32` type. inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { return builder.create(loc, i, 32); } /// Generates a constant of `i16` type. inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { return builder.create(loc, i, 16); } /// Generates a constant of `i8` type. inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { return builder.create(loc, i, 8); } /// Generates a constant of `i1` type. inline Value constantI1(OpBuilder &builder, Location loc, bool b) { return builder.create(loc, b, 1); } /// Generates a constant of the given `Action`. inline Value constantAction(OpBuilder &builder, Location loc, Action action) { return constantI32(builder, loc, static_cast(action)); } /// Generates a constant of the internal type-encoding for overhead storage. inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, unsigned width) { return constantI32(builder, loc, static_cast(overheadTypeEncoding(width))); } /// Generates a constant of the internal type-encoding for position /// overhead storage. inline Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc) { return constantOverheadTypeEncoding(builder, loc, enc.getPosWidth()); } /// Generates a constant of the internal type-encoding for coordinate /// overhead storage. inline Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc) { return constantOverheadTypeEncoding(builder, loc, enc.getCrdWidth()); } /// Generates a constant of the internal type-encoding for primary storage. inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp) { return constantI32(builder, loc, static_cast(primaryTypeEncoding(elemTp))); } /// Generates a constant of the internal dimension level type encoding. inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt) { return constantI64(builder, loc, static_cast(lt)); } // Generates a constant from a validated value carrying attribute. inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) { if (auto complexAttr = dyn_cast(attr)) { Type tp = cast(complexAttr.getType()).getElementType(); return builder.create( loc, complexAttr.getType(), builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()), FloatAttr::get(tp, complexAttr.getImag())})); } return builder.create(loc, cast(attr)); } // TODO: is this at the right place? inline bool isZeroRankedTensorOrScalar(Type type) { auto rtp = dyn_cast(type); return !rtp || rtp.getRank() == 0; } } // namespace sparse_tensor } // namespace mlir #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_