//===- Utils.h - Utilities to support the Tensor dialect -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TENSOR_UTILS_UTILS_H_
#define MLIR_DIALECT_TENSOR_UTILS_UTILS_H_

#include "mlir/Dialect/Tensor/IR/Tensor.h"

namespace mlir {
namespace tensor {

// Returns a PadOp that pads `source` to `resType` size. The op performs "high"
// padding, i.e. it adds trailing padding values until the desired size is met.
// Output sizes are assumed to be greater than the input sizes. The padding
// width is calculated as: resDim - sourceDim.
//
// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
//  1. dynamic input sizes are extracted from `source`,
//  2. for dynamic output dims, there are two options:
//    2.1 all output dynamic dim sizes are specified in `dynOutDim`,
//    2.2 `dynOutDim` is empty and the corresponding padding width is set to 0.
//
// (*) Note that `resType` is just a shape and it only encodes the actual sizes
// for _static_ dimensions.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
                      bool nofold, Location loc, OpBuilder &builder,
                      SmallVector<Value> dynOutDim = {});

// Creates a tensor::DimOp for each dynamic dimension of the ranked tensor
// argument and returns these as values.
SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
                                          Value rankedTensor);

/// Returns the transposed `rankedTensorType` if `transposeVector` is non-empty.
/// Fails if `transposeVector` is not a permutation matching the tensor rank.
FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
                      ArrayRef<int64_t> transposeVector);

/// Shell function to compute the destination permutation of a PackOp.
/// This function uses the helper function `computePackUnPackPerm` to get
/// the permutation vector. The only major difference between UnPackOp and
/// PackOp is that PackOp uses the destination rank whereas UnPackOp uses the
/// source rank.
SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);

/// Shell function to compute the source permutation of an UnPackOp.
/// This function, like `getPackInverseDestPerm`, uses the helper function
/// `computePackUnPackPerm` to get the permutation vector. The only major
/// difference between UnPackOp and PackOp is that PackOp uses the destination
/// rank whereas UnPackOp uses the source rank.
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);

/// Shell function to compute the source rank permutation for an UnPackOp.
/// UnPackOp requires some packing metadata information, so this overload
/// additionally returns that metadata through the by-reference `metadata`
/// out-parameter.
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
                                             PackingMetadata &metadata);

/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
bool isCastLikeInsertSliceOp(InsertSliceOp op);

/// A tensor.extract_slice is a cast-like operation if it merely rank-reduces
/// unit dimensions of the source tensor or extracts the entire source tensor.
bool isCastLikeExtractSliceOp(ExtractSliceOp op);

} // namespace tensor
} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_UTILS_UTILS_H_