xref: /llvm-project/mlir/lib/Interfaces/InferTypeOpInterface.cpp (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23 } // namespace mlir
24 
25 LogicalResult
26 mlir::reifyResultShapes(OpBuilder &b, Operation *op,
27                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
28   auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
29   if (!reifiableOp)
30     return failure();
31   LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
32 #ifndef NDEBUG
33   if (failed(status))
34     return failure();
35   // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
36   // a correct result.
37   int64_t resultIdx = 0;
38   for (OpResult result : op->getResults()) {
39     auto shapedType = dyn_cast<ShapedType>(result.getType());
40     if (!shapedType)
41       continue;
42     if (!shapedType.hasRank()) {
43       // Nothing to check for unranked shaped values.
44       ++resultIdx;
45       continue;
46     }
47     // Assert one OpFoldResult per dimension.
48     assert(shapedType.getRank() ==
49                static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
50            "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51     ++resultIdx;
52   }
53   // Assert that every shaped value result was reified.
54   assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
55          "incorrect implementation of ReifyRankedShapedTypeOpInterface");
56 #endif // NDEBUG
57   return status;
58 }
59 
60 bool ShapeAdaptor::hasRank() const {
61   if (val.isNull())
62     return false;
63   if (auto t = llvm::dyn_cast_if_present<Type>(val))
64     return cast<ShapedType>(t).hasRank();
65   if (isa<Attribute>(val))
66     return true;
67   return cast<ShapedTypeComponents *>(val)->hasRank();
68 }
69 
70 Type ShapeAdaptor::getElementType() const {
71   if (val.isNull())
72     return nullptr;
73   if (auto t = llvm::dyn_cast_if_present<Type>(val))
74     return cast<ShapedType>(t).getElementType();
75   if (isa<Attribute>(val))
76     return nullptr;
77   return cast<ShapedTypeComponents *>(val)->getElementType();
78 }
79 
80 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
81   assert(hasRank());
82   if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
83     ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
84     res.assign(vals.begin(), vals.end());
85   } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
86     auto dattr = cast<DenseIntElementsAttr>(attr);
87     res.clear();
88     res.reserve(dattr.size());
89     for (auto it : dattr.getValues<APInt>())
90       res.push_back(it.getSExtValue());
91   } else {
92     auto vals = cast<ShapedTypeComponents *>(val)->getDims();
93     res.assign(vals.begin(), vals.end());
94   }
95 }
96 
97 void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
98   assert(hasRank());
99   res.ranked = true;
100   getDims(res.dims);
101 }
102 
103 int64_t ShapeAdaptor::getDimSize(int index) const {
104   assert(hasRank());
105   if (auto t = llvm::dyn_cast_if_present<Type>(val))
106     return cast<ShapedType>(t).getDimSize(index);
107   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
108     return cast<DenseIntElementsAttr>(attr)
109         .getValues<APInt>()[index]
110         .getSExtValue();
111   auto *stc = cast<ShapedTypeComponents *>(val);
112   return stc->getDims()[index];
113 }
114 
115 int64_t ShapeAdaptor::getRank() const {
116   assert(hasRank());
117   if (auto t = llvm::dyn_cast_if_present<Type>(val))
118     return cast<ShapedType>(t).getRank();
119   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
120     return cast<DenseIntElementsAttr>(attr).size();
121   return cast<ShapedTypeComponents *>(val)->getDims().size();
122 }
123 
124 bool ShapeAdaptor::hasStaticShape() const {
125   if (!hasRank())
126     return false;
127 
128   if (auto t = llvm::dyn_cast_if_present<Type>(val))
129     return cast<ShapedType>(t).hasStaticShape();
130   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
131     auto dattr = cast<DenseIntElementsAttr>(attr);
132     for (auto index : dattr.getValues<APInt>())
133       if (ShapedType::isDynamic(index.getSExtValue()))
134         return false;
135     return true;
136   }
137   auto *stc = cast<ShapedTypeComponents *>(val);
138   return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
139 }
140 
141 int64_t ShapeAdaptor::getNumElements() const {
142   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
143 
144   if (auto t = llvm::dyn_cast_if_present<Type>(val))
145     return cast<ShapedType>(t).getNumElements();
146 
147   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
148     auto dattr = cast<DenseIntElementsAttr>(attr);
149     int64_t num = 1;
150     for (auto index : dattr.getValues<APInt>()) {
151       num *= index.getZExtValue();
152       assert(num >= 0 && "integer overflow in element count computation");
153     }
154     return num;
155   }
156 
157   auto *stc = cast<ShapedTypeComponents *>(val);
158   int64_t num = 1;
159   for (int64_t dim : stc->getDims()) {
160     num *= dim;
161     assert(num >= 0 && "integer overflow in element count computation");
162   }
163   return num;
164 }
165 
166 void ShapeAdaptor::dump() const {
167   if (!hasRank()) {
168     llvm::errs() << "<<unranked>>\n";
169     return;
170   }
171 
172   SmallVector<int64_t> dims;
173   getDims(dims);
174   auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
175     if (ShapedType::isDynamic(dim))
176       return "?";
177     return llvm::formatv("{0}", dim).str();
178   });
179   llvm::errs() << "rank = " << getRank() << " dims = [";
180   llvm::interleave(mapped, llvm::errs(), "x");
181   llvm::errs() << "]\n";
182 }
183 
184 ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
185   Value val = operator[](index);
186   if (valueToShape)
187     if (ShapeAdaptor ret = valueToShape(val))
188       return ret;
189 
190   DenseIntElementsAttr attr;
191   if (!matchPattern(val, m_Constant(&attr)))
192     return nullptr;
193   if (attr.getType().getRank() != 1)
194     return nullptr;
195   return attr;
196 }
197 
198 ShapeAdaptor ValueShapeRange::getShape(Value val) const {
199   if (operandShape)
200     if (ShapeAdaptor ret = operandShape(val))
201       return ret;
202   return val.getType();
203 }
204 
205 ShapeAdaptor ValueShapeRange::getShape(int index) const {
206   if (index < 0 || static_cast<size_t>(index) >= size())
207     return nullptr;
208   return getShape(operator[](index));
209 }
210 
211 LogicalResult mlir::detail::inferReturnTensorTypes(
212     ArrayRef<ShapedTypeComponents> retComponents,
213     SmallVectorImpl<Type> &inferredReturnTypes) {
214   for (const auto &shapeAndType : retComponents) {
215     Type elementTy = shapeAndType.getElementType();
216     assert(elementTy && "element type required to construct tensor");
217 
218     Attribute attr = shapeAndType.getAttribute();
219     if (shapeAndType.hasRank()) {
220       inferredReturnTypes.push_back(
221           RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
222     } else {
223       assert(attr == nullptr && "attribute not supported");
224       inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
225     }
226   }
227   return success();
228 }
229 
230 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
231   SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
232   auto retTypeFn = cast<InferTypeOpInterface>(op);
233   auto result = retTypeFn.refineReturnTypes(
234       op->getContext(), op->getLoc(), op->getOperands(),
235       op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(),
236       inferredReturnTypes);
237   if (failed(result))
238     op->emitOpError() << "failed to infer returned types";
239 
240   return result;
241 }
242 
243 void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) {
244   std::string buffer;
245   llvm::raw_string_ostream os(buffer);
246   os << "Failed to infer result type(s):\n";
247   os << "\"" << state.name << "\"(...) ";
248   os << state.attributes.getDictionary(state.location.getContext());
249   os << " : (";
250   llvm::interleaveComma(state.operands, os,
251                         [&](Value val) { os << val.getType(); });
252   os << ") -> ( ??? )";
253   emitRemark(state.location, "location of op");
254   llvm::report_fatal_error(llvm::StringRef(buffer));
255 }
256