xref: /llvm-project/mlir/lib/Interfaces/InferTypeOpInterface.cpp (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
17ce1e7abSRiver Riddle //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
27ce1e7abSRiver Riddle //
37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67ce1e7abSRiver Riddle //
77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
87ce1e7abSRiver Riddle //
97ce1e7abSRiver Riddle // This file contains the definitions of the infer op interfaces defined in
107ce1e7abSRiver Riddle // `InferTypeOpInterface.td`.
117ce1e7abSRiver Riddle //
127ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
137ce1e7abSRiver Riddle 
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

#include <limits>
187ce1e7abSRiver Riddle 
197ce1e7abSRiver Riddle using namespace mlir;
207ce1e7abSRiver Riddle 
217ce1e7abSRiver Riddle namespace mlir {
227ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
237ce1e7abSRiver Riddle } // namespace mlir
247ce1e7abSRiver Riddle 
25758329dcSMatthias Springer LogicalResult
26758329dcSMatthias Springer mlir::reifyResultShapes(OpBuilder &b, Operation *op,
27758329dcSMatthias Springer                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
28758329dcSMatthias Springer   auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
29758329dcSMatthias Springer   if (!reifiableOp)
30758329dcSMatthias Springer     return failure();
31758329dcSMatthias Springer   LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
32758329dcSMatthias Springer #ifndef NDEBUG
33758329dcSMatthias Springer   if (failed(status))
34758329dcSMatthias Springer     return failure();
35758329dcSMatthias Springer   // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
36758329dcSMatthias Springer   // a correct result.
37758329dcSMatthias Springer   int64_t resultIdx = 0;
38758329dcSMatthias Springer   for (OpResult result : op->getResults()) {
395550c821STres Popp     auto shapedType = dyn_cast<ShapedType>(result.getType());
40758329dcSMatthias Springer     if (!shapedType)
41758329dcSMatthias Springer       continue;
42758329dcSMatthias Springer     if (!shapedType.hasRank()) {
43758329dcSMatthias Springer       // Nothing to check for unranked shaped values.
44758329dcSMatthias Springer       ++resultIdx;
45758329dcSMatthias Springer       continue;
46758329dcSMatthias Springer     }
47758329dcSMatthias Springer     // Assert one OpFoldResult per dimension.
48758329dcSMatthias Springer     assert(shapedType.getRank() ==
49758329dcSMatthias Springer                static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
50758329dcSMatthias Springer            "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51758329dcSMatthias Springer     ++resultIdx;
52758329dcSMatthias Springer   }
53758329dcSMatthias Springer   // Assert that every shaped value result was reified.
54758329dcSMatthias Springer   assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
55758329dcSMatthias Springer          "incorrect implementation of ReifyRankedShapedTypeOpInterface");
56758329dcSMatthias Springer #endif // NDEBUG
57758329dcSMatthias Springer   return status;
58758329dcSMatthias Springer }
59758329dcSMatthias Springer 
6009349303SJacques Pienaar bool ShapeAdaptor::hasRank() const {
6109349303SJacques Pienaar   if (val.isNull())
6209349303SJacques Pienaar     return false;
6368f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val))
645550c821STres Popp     return cast<ShapedType>(t).hasRank();
65*9192367aSKazu Hirata   if (isa<Attribute>(val))
6609349303SJacques Pienaar     return true;
67*9192367aSKazu Hirata   return cast<ShapedTypeComponents *>(val)->hasRank();
6809349303SJacques Pienaar }
6909349303SJacques Pienaar 
7009349303SJacques Pienaar Type ShapeAdaptor::getElementType() const {
7109349303SJacques Pienaar   if (val.isNull())
7209349303SJacques Pienaar     return nullptr;
7368f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val))
745550c821STres Popp     return cast<ShapedType>(t).getElementType();
75*9192367aSKazu Hirata   if (isa<Attribute>(val))
7609349303SJacques Pienaar     return nullptr;
77*9192367aSKazu Hirata   return cast<ShapedTypeComponents *>(val)->getElementType();
7809349303SJacques Pienaar }
7909349303SJacques Pienaar 
8009349303SJacques Pienaar void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
8109349303SJacques Pienaar   assert(hasRank());
8268f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
835550c821STres Popp     ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
8409349303SJacques Pienaar     res.assign(vals.begin(), vals.end());
8568f58812STres Popp   } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
865550c821STres Popp     auto dattr = cast<DenseIntElementsAttr>(attr);
8709349303SJacques Pienaar     res.clear();
8809349303SJacques Pienaar     res.reserve(dattr.size());
890cb5d7fcSRiver Riddle     for (auto it : dattr.getValues<APInt>())
9009349303SJacques Pienaar       res.push_back(it.getSExtValue());
9109349303SJacques Pienaar   } else {
92*9192367aSKazu Hirata     auto vals = cast<ShapedTypeComponents *>(val)->getDims();
9309349303SJacques Pienaar     res.assign(vals.begin(), vals.end());
9409349303SJacques Pienaar   }
9509349303SJacques Pienaar }
9609349303SJacques Pienaar 
9709349303SJacques Pienaar void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
9809349303SJacques Pienaar   assert(hasRank());
9909349303SJacques Pienaar   res.ranked = true;
10009349303SJacques Pienaar   getDims(res.dims);
10109349303SJacques Pienaar }
10209349303SJacques Pienaar 
10309349303SJacques Pienaar int64_t ShapeAdaptor::getDimSize(int index) const {
10409349303SJacques Pienaar   assert(hasRank());
10568f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val))
1065550c821STres Popp     return cast<ShapedType>(t).getDimSize(index);
10768f58812STres Popp   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
1085550c821STres Popp     return cast<DenseIntElementsAttr>(attr)
109ae40d625SRiver Riddle         .getValues<APInt>()[index]
11009349303SJacques Pienaar         .getSExtValue();
111*9192367aSKazu Hirata   auto *stc = cast<ShapedTypeComponents *>(val);
11209349303SJacques Pienaar   return stc->getDims()[index];
11309349303SJacques Pienaar }
11409349303SJacques Pienaar 
11509349303SJacques Pienaar int64_t ShapeAdaptor::getRank() const {
11609349303SJacques Pienaar   assert(hasRank());
11768f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val))
1185550c821STres Popp     return cast<ShapedType>(t).getRank();
11968f58812STres Popp   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
1205550c821STres Popp     return cast<DenseIntElementsAttr>(attr).size();
121*9192367aSKazu Hirata   return cast<ShapedTypeComponents *>(val)->getDims().size();
12209349303SJacques Pienaar }
12309349303SJacques Pienaar 
12409349303SJacques Pienaar bool ShapeAdaptor::hasStaticShape() const {
12509349303SJacques Pienaar   if (!hasRank())
12609349303SJacques Pienaar     return false;
12709349303SJacques Pienaar 
12868f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val))
1295550c821STres Popp     return cast<ShapedType>(t).hasStaticShape();
13068f58812STres Popp   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
1315550c821STres Popp     auto dattr = cast<DenseIntElementsAttr>(attr);
1320cb5d7fcSRiver Riddle     for (auto index : dattr.getValues<APInt>())
13309349303SJacques Pienaar       if (ShapedType::isDynamic(index.getSExtValue()))
13409349303SJacques Pienaar         return false;
13509349303SJacques Pienaar     return true;
13609349303SJacques Pienaar   }
137*9192367aSKazu Hirata   auto *stc = cast<ShapedTypeComponents *>(val);
138380a1b20SKazu Hirata   return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
13909349303SJacques Pienaar }
14009349303SJacques Pienaar 
14109349303SJacques Pienaar int64_t ShapeAdaptor::getNumElements() const {
14209349303SJacques Pienaar   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
14309349303SJacques Pienaar 
14468f58812STres Popp   if (auto t = llvm::dyn_cast_if_present<Type>(val))
1455550c821STres Popp     return cast<ShapedType>(t).getNumElements();
14609349303SJacques Pienaar 
14768f58812STres Popp   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
1485550c821STres Popp     auto dattr = cast<DenseIntElementsAttr>(attr);
14909349303SJacques Pienaar     int64_t num = 1;
1500cb5d7fcSRiver Riddle     for (auto index : dattr.getValues<APInt>()) {
15109349303SJacques Pienaar       num *= index.getZExtValue();
15209349303SJacques Pienaar       assert(num >= 0 && "integer overflow in element count computation");
15309349303SJacques Pienaar     }
15409349303SJacques Pienaar     return num;
15509349303SJacques Pienaar   }
15609349303SJacques Pienaar 
157*9192367aSKazu Hirata   auto *stc = cast<ShapedTypeComponents *>(val);
15809349303SJacques Pienaar   int64_t num = 1;
15909349303SJacques Pienaar   for (int64_t dim : stc->getDims()) {
16009349303SJacques Pienaar     num *= dim;
16109349303SJacques Pienaar     assert(num >= 0 && "integer overflow in element count computation");
16209349303SJacques Pienaar   }
16309349303SJacques Pienaar   return num;
16409349303SJacques Pienaar }
16509349303SJacques Pienaar 
16609349303SJacques Pienaar void ShapeAdaptor::dump() const {
16709349303SJacques Pienaar   if (!hasRank()) {
16809349303SJacques Pienaar     llvm::errs() << "<<unranked>>\n";
16909349303SJacques Pienaar     return;
17009349303SJacques Pienaar   }
17109349303SJacques Pienaar 
17209349303SJacques Pienaar   SmallVector<int64_t> dims;
17309349303SJacques Pienaar   getDims(dims);
17409349303SJacques Pienaar   auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
17509349303SJacques Pienaar     if (ShapedType::isDynamic(dim))
17609349303SJacques Pienaar       return "?";
17709349303SJacques Pienaar     return llvm::formatv("{0}", dim).str();
17809349303SJacques Pienaar   });
17909349303SJacques Pienaar   llvm::errs() << "rank = " << getRank() << " dims = [";
18009349303SJacques Pienaar   llvm::interleave(mapped, llvm::errs(), "x");
18109349303SJacques Pienaar   llvm::errs() << "]\n";
18209349303SJacques Pienaar }
18309349303SJacques Pienaar 
18409349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
18509349303SJacques Pienaar   Value val = operator[](index);
18609349303SJacques Pienaar   if (valueToShape)
18709349303SJacques Pienaar     if (ShapeAdaptor ret = valueToShape(val))
18809349303SJacques Pienaar       return ret;
18909349303SJacques Pienaar 
19009349303SJacques Pienaar   DenseIntElementsAttr attr;
19109349303SJacques Pienaar   if (!matchPattern(val, m_Constant(&attr)))
19209349303SJacques Pienaar     return nullptr;
19309349303SJacques Pienaar   if (attr.getType().getRank() != 1)
19409349303SJacques Pienaar     return nullptr;
19509349303SJacques Pienaar   return attr;
19609349303SJacques Pienaar }
19709349303SJacques Pienaar 
19809349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(Value val) const {
19909349303SJacques Pienaar   if (operandShape)
20009349303SJacques Pienaar     if (ShapeAdaptor ret = operandShape(val))
20109349303SJacques Pienaar       return ret;
20209349303SJacques Pienaar   return val.getType();
20309349303SJacques Pienaar }
20409349303SJacques Pienaar 
20509349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(int index) const {
20609349303SJacques Pienaar   if (index < 0 || static_cast<size_t>(index) >= size())
20709349303SJacques Pienaar     return nullptr;
20809349303SJacques Pienaar   return getShape(operator[](index));
20909349303SJacques Pienaar }
21009349303SJacques Pienaar 
2117ce1e7abSRiver Riddle LogicalResult mlir::detail::inferReturnTensorTypes(
21247b0a9b9SAmanda Tang     ArrayRef<ShapedTypeComponents> retComponents,
2137ce1e7abSRiver Riddle     SmallVectorImpl<Type> &inferredReturnTypes) {
214e4853be2SMehdi Amini   for (const auto &shapeAndType : retComponents) {
215e8c961f5SMehdi Amini     Type elementTy = shapeAndType.getElementType();
216e8c961f5SMehdi Amini     assert(elementTy && "element type required to construct tensor");
217b2505ca2Ssmit-hinsu 
218b2505ca2Ssmit-hinsu     Attribute attr = shapeAndType.getAttribute();
219b2505ca2Ssmit-hinsu     if (shapeAndType.hasRank()) {
2207ce1e7abSRiver Riddle       inferredReturnTypes.push_back(
221e8c961f5SMehdi Amini           RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
222b2505ca2Ssmit-hinsu     } else {
223b2505ca2Ssmit-hinsu       assert(attr == nullptr && "attribute not supported");
224e8c961f5SMehdi Amini       inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
225b2505ca2Ssmit-hinsu     }
2267ce1e7abSRiver Riddle   }
2277ce1e7abSRiver Riddle   return success();
2287ce1e7abSRiver Riddle }
2297ce1e7abSRiver Riddle 
2307ce1e7abSRiver Riddle LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
231c8598fa2SJacques Pienaar   SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
2327ce1e7abSRiver Riddle   auto retTypeFn = cast<InferTypeOpInterface>(op);
2335e118f93SMehdi Amini   auto result = retTypeFn.refineReturnTypes(
2345e118f93SMehdi Amini       op->getContext(), op->getLoc(), op->getOperands(),
235b336ab42SOleksandr "Alex" Zinenko       op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(),
236b336ab42SOleksandr "Alex" Zinenko       inferredReturnTypes);
2375e118f93SMehdi Amini   if (failed(result))
2385e118f93SMehdi Amini     op->emitOpError() << "failed to infer returned types";
2395e118f93SMehdi Amini 
2405e118f93SMehdi Amini   return result;
2417ce1e7abSRiver Riddle }
24236d936a2SMatthias Springer 
24336d936a2SMatthias Springer void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) {
24436d936a2SMatthias Springer   std::string buffer;
24536d936a2SMatthias Springer   llvm::raw_string_ostream os(buffer);
24636d936a2SMatthias Springer   os << "Failed to infer result type(s):\n";
24736d936a2SMatthias Springer   os << "\"" << state.name << "\"(...) ";
24836d936a2SMatthias Springer   os << state.attributes.getDictionary(state.location.getContext());
24936d936a2SMatthias Springer   os << " : (";
25036d936a2SMatthias Springer   llvm::interleaveComma(state.operands, os,
25136d936a2SMatthias Springer                         [&](Value val) { os << val.getType(); });
25236d936a2SMatthias Springer   os << ") -> ( ??? )";
25336d936a2SMatthias Springer   emitRemark(state.location, "location of op");
25436d936a2SMatthias Springer   llvm::report_fatal_error(llvm::StringRef(buffer));
25536d936a2SMatthias Springer }
256