1 //===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_ 10 #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ 11 12 #include "mlir/Bytecode/BytecodeOpInterface.h" 13 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/Dialect.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/Interfaces/CastInterfaces.h" 19 #include "mlir/Interfaces/ControlFlowInterfaces.h" 20 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 21 #include "mlir/Interfaces/InferTypeOpInterface.h" 22 #include "mlir/Interfaces/ParallelCombiningOpInterface.h" 23 #include "mlir/Interfaces/ShapedOpInterfaces.h" 24 #include "mlir/Interfaces/SideEffectInterfaces.h" 25 #include "mlir/Interfaces/TilingInterface.h" 26 #include "mlir/Interfaces/ViewLikeInterface.h" 27 28 //===----------------------------------------------------------------------===// 29 // Tensor Dialect Helpers 30 //===----------------------------------------------------------------------===// 31 32 namespace mlir { 33 34 /// Return the list of Range (i.e. offset, size, stride). Each Range 35 /// entry contains either the dynamic value or a ConstantIndexOp constructed 36 /// with `b` at location `loc`. 37 SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 38 OpBuilder &b, Location loc); 39 40 } // namespace mlir 41 42 //===----------------------------------------------------------------------===// 43 // Tensor Dialect 44 //===----------------------------------------------------------------------===// 45 46 #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc" 47 48 //===----------------------------------------------------------------------===// 49 // Tensor Dialect Operations 50 //===----------------------------------------------------------------------===// 51 52 #define GET_OP_CLASSES 53 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc" 54 55 //===----------------------------------------------------------------------===// 56 // Tensor Dialect Helpers 57 //===----------------------------------------------------------------------===// 58 59 namespace mlir { 60 namespace tensor { 61 62 /// Returns true if `target` is a ranked tensor type that preserves static 63 /// information available in the `source` ranked tensor type. 64 bool preservesStaticInformation(Type source, Type target); 65 66 /// Determines whether tensor::CastOp casts to a more dynamic version of the 67 /// source tensor. This is useful to fold a tensor.cast into a consuming op and 68 /// implement canonicalization patterns for ops in different dialects that may 69 /// consume the results of tensor.cast operations. Such foldable tensor.cast 70 /// operations are typically inserted as `extract_slice` ops and are 71 /// canonicalized, to preserve the type compatibility of their uses. 72 /// 73 /// Returns true when all conditions are met: 74 /// 1. source and result are ranked tensors with same element type and rank. 75 /// 2. the tensor type has more static information than the result 76 /// 77 /// Example: 78 /// ```mlir 79 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 80 /// %2 = consumer %1 ... : tensor<?x?xf32> ... 81 /// ``` 82 /// 83 /// folds into: 84 /// 85 /// ```mlir 86 /// %2 = consumer %0 ... : tensor<8x16xf32> ... 87 /// ``` 88 bool canFoldIntoConsumerOp(CastOp castOp); 89 90 /// Determines whether the tensor::CastOp casts to a more static version of the 91 /// source tensor. This is useful to fold into a producing op and implement 92 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer 93 /// being from different dialects. Returns true when all conditions are met: 94 /// 1. source and result and ranked tensors with same element type and rank. 95 /// 2. the result type has more static information than the source. 96 /// 97 /// Example: 98 /// ```mlir 99 /// %1 = producer ... : tensor<?x?xf32> 100 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32> 101 /// ``` 102 /// 103 /// can be canonicalized to : 104 /// 105 /// ```mlir 106 /// %2 = producer ... : tensor<8x16xf32> 107 /// ``` 108 /// Not all ops might be canonicalizable this way, but for those that can be, 109 /// this method provides a check that it is worth doing the canonicalization. 110 bool canFoldIntoProducerOp(CastOp castOp); 111 112 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp 113 /// that can be folded. 114 LogicalResult foldTensorCast(Operation *op); 115 116 /// Return the dimension of the given tensor value. 117 OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, 118 int64_t dim); 119 120 /// Return the dimensions of the given tensor value. 121 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc, 122 Value value); 123 124 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and 125 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor` 126 /// to that of `targetType`. 127 Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, 128 Value tensor, 129 RankedTensorType targetType); 130 131 /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and 132 /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with 133 /// rank increased to that of `dest`, obtained by inserting `tensor` into `dest` 134 /// at the canonical [0 .. 0] position. 135 Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, 136 Value tensor, Value dest); 137 138 /// This is a helper function for DestinationStyleOpInterface. If there is a 139 /// destination operand for the given OpResult, return that operand. Otherwise, 140 /// return an empty tensor (`tensor.empty`) with the shape of the OpResult. 141 /// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface. 142 FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc, 143 OpResult opResult); 144 145 /// This is a helper function for DestinationStyleOpInterface. Get or create 146 /// destinations for every tensor OpResult of the given op. 147 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, 148 SmallVector<Value> &result); 149 150 /// Tests if types are the same when ignoring encoding on ranked tensors. 151 bool isSameTypeWithoutEncoding(Type tp1, Type tp2); 152 153 /// Function to control the folding of constant and extract slice. 154 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>; 155 156 /// Patterns to fold the extract slice op with its constant operand. 157 void populateFoldConstantExtractSlicePatterns( 158 RewritePatternSet &patterns, 159 const ControlConstantExtractSliceFusionFn &controlFn = 160 [](ExtractSliceOp op) { 161 // Disable by default because the folding can generate a large 162 // constant tensor, which would affect the compile time and storage. 163 return false; 164 }); 165 166 } // namespace tensor 167 } // namespace mlir 168 169 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_ 170