1 //===- Utils.h - General Arith transformation utilities ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines prototypes for various transformation utilities for 10 // the Arith dialect. These are not passes by themselves but are used 11 // either by passes, optimization sequences, or in turn by other transformation 12 // utilities. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #ifndef MLIR_DIALECT_ARITH_UTILS_UTILS_H 17 #define MLIR_DIALECT_ARITH_UTILS_UTILS_H 18 19 #include "mlir/Dialect/Arith/IR/Arith.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/Value.h" 23 #include "llvm/ADT/ArrayRef.h" 24 25 namespace mlir { 26 27 using ReassociationIndices = SmallVector<int64_t, 2>; 28 29 /// Infer the output shape for a {memref|tensor}.expand_shape when it is 30 /// possible to do so. 31 /// 32 /// Note: This should *only* be used to implement 33 /// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces. 34 /// If you need to infer the output shape you should use the static method of 35 /// `ExpandShapeOp` instead of calling this. 36 /// 37 /// `inputShape` is the shape of the tensor or memref being expanded as a 38 /// sequence of SSA values or constants. `expandedType` is the output shape of 39 /// the expand_shape operation. `reassociation` is the reassociation denoting 40 /// the output dims each input dim is mapped to. 41 /// 42 /// Returns the output shape in `outputShape` and `staticOutputShape`, following 43 /// the conventions for the output_shape and static_output_shape inputs to the 44 /// expand_shape ops. 45 std::optional<SmallVector<OpFoldResult>> 46 inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, 47 ArrayRef<ReassociationIndices> reassociation, 48 ArrayRef<OpFoldResult> inputShape); 49 50 /// Matches a ConstantIndexOp. 51 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex(); 52 53 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, 54 ArrayRef<int64_t> shape); 55 56 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to 57 /// a Value or creates a ConstantOp if it casts to an Integer Attribute. 58 /// Other attribute types are not supported. 59 Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, 60 OpFoldResult ofr); 61 62 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to 63 /// a Value or creates a ConstantIndexOp if it casts to an Integer Attribute. 64 /// Other attribute types are not supported. 65 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 66 OpFoldResult ofr); 67 68 /// Similar to the other overload, but converts multiple OpFoldResults into 69 /// Values. 70 SmallVector<Value> 71 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 72 ArrayRef<OpFoldResult> valueOrAttrVec); 73 74 /// Create a cast from an index-like value (index or integer) to another 75 /// index-like value. If the value type and the target type are the same, it 76 /// returns the original value. 77 Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, 78 Type targetType, Value value); 79 80 /// Converts a scalar value `operand` to type `toType`. If the value doesn't 81 /// convert, a warning will be issued and the operand is returned as is (which 82 /// will presumably yield a verification issue downstream). 83 Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, 84 Type toType, bool isUnsignedCast); 85 86 /// Create a constant of type `type` at location `loc` whose value is `value` 87 /// (an APInt or APFloat whose type must match the element type of `type`). 88 /// If `type` is a shaped type, create a splat constant of the given value. 89 /// Constants are folded if possible. 90 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, 91 const APInt &value); 92 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, 93 int64_t value); 94 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, 95 const APFloat &value); 96 97 /// Returns the int type of the integer in ofr. 98 /// Other attribute types are not supported. 99 Type getType(OpFoldResult ofr); 100 101 /// Helper struct to build simple arithmetic quantities with minimal type 102 /// inference support. 103 struct ArithBuilder { 104 ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {} 105 106 Value _and(Value lhs, Value rhs); 107 Value add(Value lhs, Value rhs); 108 Value sub(Value lhs, Value rhs); 109 Value mul(Value lhs, Value rhs); 110 Value select(Value cmp, Value lhs, Value rhs); 111 Value sgt(Value lhs, Value rhs); 112 Value slt(Value lhs, Value rhs); 113 114 private: 115 OpBuilder &b; 116 Location loc; 117 }; 118 119 namespace arith { 120 121 // Build the product of a sequence. 122 // If values = (v0, v1, ..., vn) than the returned 123 // value is v0 * v1 * ... * vn. 124 // All values must have the same type. 125 // 126 // The version without `resultType` must contain at least one element in values. 127 // Then the result will have the same type as the elements in `values`. 128 // If `values` is empty in the version with `resultType` returns 1 with type 129 // `resultType`. 130 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values); 131 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, 132 Type resultType); 133 134 // Map strings to float types. 135 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name); 136 137 } // namespace arith 138 } // namespace mlir 139 140 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H 141