xref: /llvm-project/mlir/include/mlir/Dialect/Arith/Utils/Utils.h (revision 1fd1f65569f565b5b06fd9464b3b91fcd6f2fa2a)
1 //===- Utils.h - General Arith transformation utilities ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes for various transformation utilities for
10 // the Arith dialect. These are not passes by themselves but are used
11 // either by passes, optimization sequences, or in turn by other transformation
12 // utilities.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_DIALECT_ARITH_UTILS_UTILS_H
17 #define MLIR_DIALECT_ARITH_UTILS_UTILS_H
18 
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/Value.h"
23 #include "llvm/ADT/ArrayRef.h"
24 
25 namespace mlir {
26 
27 using ReassociationIndices = SmallVector<int64_t, 2>;
28 
29 /// Infer the output shape for a {memref|tensor}.expand_shape when it is
30 /// possible to do so.
31 ///
32 /// Note: This should *only* be used to implement
33 /// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
34 /// If you need to infer the output shape you should use the static method of
35 /// `ExpandShapeOp` instead of calling this.
36 ///
37 /// `inputShape` is the shape of the tensor or memref being expanded as a
38 /// sequence of SSA values or constants. `expandedType` is the output shape of
39 /// the expand_shape operation. `reassociation` is the reassociation denoting
40 /// the output dims each input dim is mapped to.
41 ///
42 /// Returns the output shape in `outputShape` and `staticOutputShape`, following
43 /// the conventions for the output_shape and static_output_shape inputs to the
44 /// expand_shape ops.
45 std::optional<SmallVector<OpFoldResult>>
46 inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
47                             ArrayRef<ReassociationIndices> reassociation,
48                             ArrayRef<OpFoldResult> inputShape);
49 
50 /// Matches a ConstantIndexOp.
51 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
52 
53 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
54                                             ArrayRef<int64_t> shape);
55 
56 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
57 /// a Value or creates a ConstantOp if it casts to an Integer Attribute.
58 /// Other attribute types are not supported.
59 Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
60                                     OpFoldResult ofr);
61 
62 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
63 /// a Value or creates a ConstantIndexOp if it casts to an Integer Attribute.
64 /// Other attribute types are not supported.
65 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
66                                       OpFoldResult ofr);
67 
68 /// Similar to the other overload, but converts multiple OpFoldResults into
69 /// Values.
70 SmallVector<Value>
71 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
72                                 ArrayRef<OpFoldResult> valueOrAttrVec);
73 
74 /// Create a cast from an index-like value (index or integer) to another
75 /// index-like value. If the value type and the target type are the same, it
76 /// returns the original value.
77 Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
78                                       Type targetType, Value value);
79 
80 /// Converts a scalar value `operand` to type `toType`. If the value doesn't
81 /// convert, a warning will be issued and the operand is returned as is (which
82 /// will presumably yield a verification issue downstream).
83 Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
84                            Type toType, bool isUnsignedCast);
85 
86 /// Create a constant of type `type` at location `loc` whose value is `value`
87 /// (an APInt or APFloat whose type must match the element type of `type`).
88 /// If `type` is a shaped type, create a splat constant of the given value.
89 /// Constants are folded if possible.
90 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
91                                   const APInt &value);
92 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
93                                   int64_t value);
94 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
95                                   const APFloat &value);
96 
97 /// Returns the int type of the integer in ofr.
98 /// Other attribute types are not supported.
99 Type getType(OpFoldResult ofr);
100 
101 /// Helper struct to build simple arithmetic quantities with minimal type
102 /// inference support.
103 struct ArithBuilder {
104   ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
105 
106   Value _and(Value lhs, Value rhs);
107   Value add(Value lhs, Value rhs);
108   Value sub(Value lhs, Value rhs);
109   Value mul(Value lhs, Value rhs);
110   Value select(Value cmp, Value lhs, Value rhs);
111   Value sgt(Value lhs, Value rhs);
112   Value slt(Value lhs, Value rhs);
113 
114 private:
115   OpBuilder &b;
116   Location loc;
117 };
118 
119 namespace arith {
120 
121 // Build the product of a sequence.
122 // If values = (v0, v1, ..., vn) than the returned
123 // value is v0 * v1 * ... * vn.
124 // All values must have the same type.
125 //
126 // The version without `resultType` must contain at least one element in values.
127 // Then the result will have the same type as the elements in `values`.
128 // If `values` is empty in the version with `resultType` returns 1 with type
129 // `resultType`.
130 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
131 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
132                     Type resultType);
133 
134 // Map strings to float types.
135 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
136 
137 } // namespace arith
138 } // namespace mlir
139 
140 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
141