//===-- ShapeUtils.h - TOSA shape support declarations ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Class declarations for shape utilities meant to assist shape propagation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
namespace tosa {
/// Statically known information for a particular Value.
///
/// This struct currently tracks only information relevant for tensor/array-like
/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
/// type as long as it is in the default "no knowledge" state returned by
/// `getPessimisticValueState`. The important invariant is that we cannot
/// claim to know something about a value which is false.
///
/// This class could also be called "dataflow facts", "lattice value", etc.
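///
/// As an illustrative sketch (values assumed, not taken from this header):
/// knowledge for a `tensor<2x?xf32>` value would carry `hasRank == true`,
/// `sizes == {2, ShapedType::kDynamic}`, and an f32 `dtype`, whereas an
/// unranked `tensor<*xf32>` value would carry only the f32 `dtype`.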
struct ValueKnowledge {
  ValueKnowledge() = delete;
  ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
      : hasError(false), hasRank(hasRank), dtype(dtype) {
    sizes.reserve(newSizes.size());
    for (auto size : newSizes)
      sizes.push_back(size);
  }

  operator bool() const { return !hasError; }

  // Get the static knowledge intrinsic to `type`.
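  //
  // A hedged usage sketch (the `tensor<2x?xf32>` type is assumed purely for
  // illustration):
  //   // For a value whose type is tensor<2x?xf32>:
  //   ValueKnowledge k = ValueKnowledge::getKnowledgeFromType(type);
  //   // k.hasRank == true, k.sizes == {2, ShapedType::kDynamic},
  //   // k.dtype is the f32 element type.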
  static ValueKnowledge getKnowledgeFromType(Type type) {
    ValueKnowledge result = getPessimisticValueState();
    if (auto shapedType = dyn_cast<ShapedType>(type)) {
      if (shapedType.hasRank()) {
        result.hasRank = true;
        result.sizes.reserve(shapedType.getRank());
        for (auto dim : shapedType.getShape())
          result.sizes.push_back(dim);
      }
      result.dtype = shapedType.getElementType();
    }
    return result;
  }

  // Return a pessimistic/conservative value state without assuming any
  // knowledge about the IR.
  static ValueKnowledge getPessimisticValueState() {
    return ValueKnowledge(false, {}, Type());
  }

  ShapedTypeComponents getShapedTypeComponents() const {
    return hasRank ? ShapedTypeComponents(sizes) : ShapedTypeComponents();
  }

  Type getType() const {
    if (hasRank)
      return RankedTensorType::get(llvm::ArrayRef(sizes), dtype);
    return UnrankedTensorType::get(dtype);
  }

  bool operator==(const ValueKnowledge &rhs) const {
    return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
  }

  // Given two pieces of static knowledge, calculate conservatively the
  // information we can be sure about.
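  //
  // Illustrative sketch (operand values assumed; `f32` is a hypothetical
  // element type): joining takes the more precise static size per dimension,
  // and conflicting static sizes flag an error.
  //   auto a = ValueKnowledge(true, {2, ShapedType::kDynamic}, f32); // 2x?
  //   auto b = ValueKnowledge(true, {ShapedType::kDynamic, 3}, f32); // ?x3
  //   auto j = ValueKnowledge::join(a, b); // j.sizes == {2, 3}
  //   auto c = ValueKnowledge(true, {4, 3}, f32);                    // 4x3
  //   auto e = ValueKnowledge::join(a, c); // e.hasError == true (2 vs 4)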
  static ValueKnowledge join(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    // Mental model: All conditions are checking how to change from the safe "no
    // knowledge" default-initialized state to a state with more knowledge
    // consistent with lhs and rhs.
    ValueKnowledge result = getPessimisticValueState();
    result.hasError = true;

    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
      return result;

    result.hasError = false;
    result.dtype = lhs.dtype;

    if (!lhs.hasRank && !rhs.hasRank)
      return result;

    if (!rhs.hasRank) {
      result.hasRank = true;
      result.sizes = lhs.sizes;
      return result;
    }

    if (!lhs.hasRank) {
      result.hasRank = true;
      result.sizes = rhs.sizes;
      return result;
    }

    if (lhs.sizes.size() != rhs.sizes.size())
      return result;

    result.hasRank = true;
    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamic);
    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
      int64_t lhsSize = lhs.sizes[i];
      int64_t rhsSize = rhs.sizes[i];
      int64_t &resultSize = result.sizes[i];
      if (lhsSize == ShapedType::kDynamic) {
        resultSize = rhsSize;
      } else if (rhsSize == ShapedType::kDynamic) {
        resultSize = lhsSize;
      } else if (lhsSize == rhsSize) {
        resultSize = lhsSize;
      } else {
        result.hasError = true;
      }
    }

    return result;
  }

  // Given two pieces of knowledge, generate a new ValueKnowledge that meets
  // (covers) both cases. E.g. if the rank of the LHS and RHS differ, the
  // resulting tensor has unknown rank.
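  //
  // Illustrative sketch (operand values assumed; `f32` is a hypothetical
  // element type): meeting keeps only the dimensions both operands agree on.
  //   auto a = ValueKnowledge(true, {2, ShapedType::kDynamic}, f32); // 2x?
  //   auto b = ValueKnowledge(true, {2, 3}, f32);                    // 2x3
  //   auto m = ValueKnowledge::meet(a, b);
  //   // m.sizes == {2, ShapedType::kDynamic}. Meeting a ranked value with
  //   // an unranked one would instead yield hasRank == false.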
  static ValueKnowledge meet(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    ValueKnowledge result = getPessimisticValueState();
    result.hasError = true;

    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
      return result;

    result.hasError = false;
    result.dtype = lhs.dtype;

    if (!lhs.hasRank || !rhs.hasRank) {
      result.hasRank = false;
      return result;
    }

    if (lhs.sizes.size() != rhs.sizes.size()) {
      result.hasRank = false;
      return result;
    }

    result.hasRank = true;
    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamic);
    for (int i = 0, e = lhs.sizes.size(); i < e; i++) {
      if (lhs.sizes[i] == rhs.sizes[i]) {
        result.sizes[i] = lhs.sizes[i];
      }
    }

    return result;
  }

  // Whether the value information has an error.
  bool hasError;
  // Whether the value has known rank.
  bool hasRank;
  // If `hasRank`, the sizes along each rank. Unknown sizes are represented as
  // `ShapedType::kDynamic`.
  llvm::SmallVector<int64_t> sizes;
  // The dtype of a tensor.
  // This is equal to nullptr if we don't know that it is a specific concrete
  // type.
  Type dtype;
};
} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H