17ce1e7abSRiver Riddle //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// 27ce1e7abSRiver Riddle // 37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67ce1e7abSRiver Riddle // 77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 87ce1e7abSRiver Riddle // 97ce1e7abSRiver Riddle // This file contains the definitions of the infer op interfaces defined in 107ce1e7abSRiver Riddle // `InferTypeOpInterface.td`. 117ce1e7abSRiver Riddle // 127ce1e7abSRiver Riddle //===----------------------------------------------------------------------===// 137ce1e7abSRiver Riddle 147ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h" 1509f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 1609349303SJacques Pienaar #include "mlir/IR/Matchers.h" 1709349303SJacques Pienaar #include "llvm/Support/FormatVariadic.h" 187ce1e7abSRiver Riddle 197ce1e7abSRiver Riddle using namespace mlir; 207ce1e7abSRiver Riddle 217ce1e7abSRiver Riddle namespace mlir { 227ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" 237ce1e7abSRiver Riddle } // namespace mlir 247ce1e7abSRiver Riddle 25758329dcSMatthias Springer LogicalResult 26758329dcSMatthias Springer mlir::reifyResultShapes(OpBuilder &b, Operation *op, 27758329dcSMatthias Springer ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 28758329dcSMatthias Springer auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); 29758329dcSMatthias Springer if (!reifiableOp) 30758329dcSMatthias Springer return failure(); 31758329dcSMatthias Springer LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes); 32758329dcSMatthias Springer #ifndef NDEBUG 33758329dcSMatthias Springer if (failed(status)) 34758329dcSMatthias Springer return failure(); 35758329dcSMatthias Springer // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced 36758329dcSMatthias Springer // a correct result. 37758329dcSMatthias Springer int64_t resultIdx = 0; 38758329dcSMatthias Springer for (OpResult result : op->getResults()) { 395550c821STres Popp auto shapedType = dyn_cast<ShapedType>(result.getType()); 40758329dcSMatthias Springer if (!shapedType) 41758329dcSMatthias Springer continue; 42758329dcSMatthias Springer if (!shapedType.hasRank()) { 43758329dcSMatthias Springer // Nothing to check for unranked shaped values. 44758329dcSMatthias Springer ++resultIdx; 45758329dcSMatthias Springer continue; 46758329dcSMatthias Springer } 47758329dcSMatthias Springer // Assert one OpFoldResult per dimension. 48758329dcSMatthias Springer assert(shapedType.getRank() == 49758329dcSMatthias Springer static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) && 50758329dcSMatthias Springer "incorrect implementation of ReifyRankedShapedTypeOpInterface"); 51758329dcSMatthias Springer ++resultIdx; 52758329dcSMatthias Springer } 53758329dcSMatthias Springer // Assert that every shaped value result was reified. 54758329dcSMatthias Springer assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) && 55758329dcSMatthias Springer "incorrect implementation of ReifyRankedShapedTypeOpInterface"); 56758329dcSMatthias Springer #endif // NDEBUG 57758329dcSMatthias Springer return status; 58758329dcSMatthias Springer } 59758329dcSMatthias Springer 6009349303SJacques Pienaar bool ShapeAdaptor::hasRank() const { 6109349303SJacques Pienaar if (val.isNull()) 6209349303SJacques Pienaar return false; 6368f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) 645550c821STres Popp return cast<ShapedType>(t).hasRank(); 65*9192367aSKazu Hirata if (isa<Attribute>(val)) 6609349303SJacques Pienaar return true; 67*9192367aSKazu Hirata return cast<ShapedTypeComponents *>(val)->hasRank(); 6809349303SJacques Pienaar } 6909349303SJacques Pienaar 7009349303SJacques Pienaar Type ShapeAdaptor::getElementType() const { 7109349303SJacques Pienaar if (val.isNull()) 7209349303SJacques Pienaar return nullptr; 7368f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) 745550c821STres Popp return cast<ShapedType>(t).getElementType(); 75*9192367aSKazu Hirata if (isa<Attribute>(val)) 7609349303SJacques Pienaar return nullptr; 77*9192367aSKazu Hirata return cast<ShapedTypeComponents *>(val)->getElementType(); 7809349303SJacques Pienaar } 7909349303SJacques Pienaar 8009349303SJacques Pienaar void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const { 8109349303SJacques Pienaar assert(hasRank()); 8268f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) { 835550c821STres Popp ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape(); 8409349303SJacques Pienaar res.assign(vals.begin(), vals.end()); 8568f58812STres Popp } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) { 865550c821STres Popp auto dattr = cast<DenseIntElementsAttr>(attr); 8709349303SJacques Pienaar res.clear(); 8809349303SJacques Pienaar res.reserve(dattr.size()); 890cb5d7fcSRiver Riddle for (auto it : dattr.getValues<APInt>()) 9009349303SJacques Pienaar res.push_back(it.getSExtValue()); 9109349303SJacques Pienaar } else { 92*9192367aSKazu Hirata auto vals = cast<ShapedTypeComponents *>(val)->getDims(); 9309349303SJacques Pienaar res.assign(vals.begin(), vals.end()); 9409349303SJacques Pienaar } 9509349303SJacques Pienaar } 9609349303SJacques Pienaar 9709349303SJacques Pienaar void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { 9809349303SJacques Pienaar assert(hasRank()); 9909349303SJacques Pienaar res.ranked = true; 10009349303SJacques Pienaar getDims(res.dims); 10109349303SJacques Pienaar } 10209349303SJacques Pienaar 10309349303SJacques Pienaar int64_t ShapeAdaptor::getDimSize(int index) const { 10409349303SJacques Pienaar assert(hasRank()); 10568f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) 1065550c821STres Popp return cast<ShapedType>(t).getDimSize(index); 10768f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) 1085550c821STres Popp return cast<DenseIntElementsAttr>(attr) 109ae40d625SRiver Riddle .getValues<APInt>()[index] 11009349303SJacques Pienaar .getSExtValue(); 111*9192367aSKazu Hirata auto *stc = cast<ShapedTypeComponents *>(val); 11209349303SJacques Pienaar return stc->getDims()[index]; 11309349303SJacques Pienaar } 11409349303SJacques Pienaar 11509349303SJacques Pienaar int64_t ShapeAdaptor::getRank() const { 11609349303SJacques Pienaar assert(hasRank()); 11768f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) 1185550c821STres Popp return cast<ShapedType>(t).getRank(); 11968f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) 1205550c821STres Popp return cast<DenseIntElementsAttr>(attr).size(); 121*9192367aSKazu Hirata return cast<ShapedTypeComponents *>(val)->getDims().size(); 12209349303SJacques Pienaar } 12309349303SJacques Pienaar 12409349303SJacques Pienaar bool ShapeAdaptor::hasStaticShape() const { 12509349303SJacques Pienaar if (!hasRank()) 12609349303SJacques Pienaar return false; 12709349303SJacques Pienaar 12868f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) 1295550c821STres Popp return cast<ShapedType>(t).hasStaticShape(); 13068f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) { 1315550c821STres Popp auto dattr = cast<DenseIntElementsAttr>(attr); 1320cb5d7fcSRiver Riddle for (auto index : dattr.getValues<APInt>()) 13309349303SJacques Pienaar if (ShapedType::isDynamic(index.getSExtValue())) 13409349303SJacques Pienaar return false; 13509349303SJacques Pienaar return true; 13609349303SJacques Pienaar } 137*9192367aSKazu Hirata auto *stc = cast<ShapedTypeComponents *>(val); 138380a1b20SKazu Hirata return llvm::none_of(stc->getDims(), ShapedType::isDynamic); 13909349303SJacques Pienaar } 14009349303SJacques Pienaar 14109349303SJacques Pienaar int64_t ShapeAdaptor::getNumElements() const { 14209349303SJacques Pienaar assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 14309349303SJacques Pienaar 14468f58812STres Popp if (auto t = llvm::dyn_cast_if_present<Type>(val)) 1455550c821STres Popp return cast<ShapedType>(t).getNumElements(); 14609349303SJacques Pienaar 14768f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) { 1485550c821STres Popp auto dattr = cast<DenseIntElementsAttr>(attr); 14909349303SJacques Pienaar int64_t num = 1; 1500cb5d7fcSRiver Riddle for (auto index : dattr.getValues<APInt>()) { 15109349303SJacques Pienaar num *= index.getZExtValue(); 15209349303SJacques Pienaar assert(num >= 0 && "integer overflow in element count computation"); 15309349303SJacques Pienaar } 15409349303SJacques Pienaar return num; 15509349303SJacques Pienaar } 15609349303SJacques Pienaar 157*9192367aSKazu Hirata auto *stc = cast<ShapedTypeComponents *>(val); 15809349303SJacques Pienaar int64_t num = 1; 15909349303SJacques Pienaar for (int64_t dim : stc->getDims()) { 16009349303SJacques Pienaar num *= dim; 16109349303SJacques Pienaar assert(num >= 0 && "integer overflow in element count computation"); 16209349303SJacques Pienaar } 16309349303SJacques Pienaar return num; 16409349303SJacques Pienaar } 16509349303SJacques Pienaar 16609349303SJacques Pienaar void ShapeAdaptor::dump() const { 16709349303SJacques Pienaar if (!hasRank()) { 16809349303SJacques Pienaar llvm::errs() << "<<unranked>>\n"; 16909349303SJacques Pienaar return; 17009349303SJacques Pienaar } 17109349303SJacques Pienaar 17209349303SJacques Pienaar SmallVector<int64_t> dims; 17309349303SJacques Pienaar getDims(dims); 17409349303SJacques Pienaar auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string { 17509349303SJacques Pienaar if (ShapedType::isDynamic(dim)) 17609349303SJacques Pienaar return "?"; 17709349303SJacques Pienaar return llvm::formatv("{0}", dim).str(); 17809349303SJacques Pienaar }); 17909349303SJacques Pienaar llvm::errs() << "rank = " << getRank() << " dims = ["; 18009349303SJacques Pienaar llvm::interleave(mapped, llvm::errs(), "x"); 18109349303SJacques Pienaar llvm::errs() << "]\n"; 18209349303SJacques Pienaar } 18309349303SJacques Pienaar 18409349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getValueAsShape(int index) { 18509349303SJacques Pienaar Value val = operator[](index); 18609349303SJacques Pienaar if (valueToShape) 18709349303SJacques Pienaar if (ShapeAdaptor ret = valueToShape(val)) 18809349303SJacques Pienaar return ret; 18909349303SJacques Pienaar 19009349303SJacques Pienaar DenseIntElementsAttr attr; 19109349303SJacques Pienaar if (!matchPattern(val, m_Constant(&attr))) 19209349303SJacques Pienaar return nullptr; 19309349303SJacques Pienaar if (attr.getType().getRank() != 1) 19409349303SJacques Pienaar return nullptr; 19509349303SJacques Pienaar return attr; 19609349303SJacques Pienaar } 19709349303SJacques Pienaar 19809349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(Value val) const { 19909349303SJacques Pienaar if (operandShape) 20009349303SJacques Pienaar if (ShapeAdaptor ret = operandShape(val)) 20109349303SJacques Pienaar return ret; 20209349303SJacques Pienaar return val.getType(); 20309349303SJacques Pienaar } 20409349303SJacques Pienaar 20509349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(int index) const { 20609349303SJacques Pienaar if (index < 0 || static_cast<size_t>(index) >= size()) 20709349303SJacques Pienaar return nullptr; 20809349303SJacques Pienaar return getShape(operator[](index)); 20909349303SJacques Pienaar } 21009349303SJacques Pienaar 2117ce1e7abSRiver Riddle LogicalResult mlir::detail::inferReturnTensorTypes( 21247b0a9b9SAmanda Tang ArrayRef<ShapedTypeComponents> retComponents, 2137ce1e7abSRiver Riddle SmallVectorImpl<Type> &inferredReturnTypes) { 214e4853be2SMehdi Amini for (const auto &shapeAndType : retComponents) { 215e8c961f5SMehdi Amini Type elementTy = shapeAndType.getElementType(); 216e8c961f5SMehdi Amini assert(elementTy && "element type required to construct tensor"); 217b2505ca2Ssmit-hinsu 218b2505ca2Ssmit-hinsu Attribute attr = shapeAndType.getAttribute(); 219b2505ca2Ssmit-hinsu if (shapeAndType.hasRank()) { 2207ce1e7abSRiver Riddle inferredReturnTypes.push_back( 221e8c961f5SMehdi Amini RankedTensorType::get(shapeAndType.getDims(), elementTy, attr)); 222b2505ca2Ssmit-hinsu } else { 223b2505ca2Ssmit-hinsu assert(attr == nullptr && "attribute not supported"); 224e8c961f5SMehdi Amini inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); 225b2505ca2Ssmit-hinsu } 2267ce1e7abSRiver Riddle } 2277ce1e7abSRiver Riddle return success(); 2287ce1e7abSRiver Riddle } 2297ce1e7abSRiver Riddle 2307ce1e7abSRiver Riddle LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { 231c8598fa2SJacques Pienaar SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes()); 2327ce1e7abSRiver Riddle auto retTypeFn = cast<InferTypeOpInterface>(op); 2335e118f93SMehdi Amini auto result = retTypeFn.refineReturnTypes( 2345e118f93SMehdi Amini op->getContext(), op->getLoc(), op->getOperands(), 235b336ab42SOleksandr "Alex" Zinenko op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(), 236b336ab42SOleksandr "Alex" Zinenko inferredReturnTypes); 2375e118f93SMehdi Amini if (failed(result)) 2385e118f93SMehdi Amini op->emitOpError() << "failed to infer returned types"; 2395e118f93SMehdi Amini 2405e118f93SMehdi Amini return result; 2417ce1e7abSRiver Riddle } 24236d936a2SMatthias Springer 24336d936a2SMatthias Springer void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) { 24436d936a2SMatthias Springer std::string buffer; 24536d936a2SMatthias Springer llvm::raw_string_ostream os(buffer); 24636d936a2SMatthias Springer os << "Failed to infer result type(s):\n"; 24736d936a2SMatthias Springer os << "\"" << state.name << "\"(...) "; 24836d936a2SMatthias Springer os << state.attributes.getDictionary(state.location.getContext()); 24936d936a2SMatthias Springer os << " : ("; 25036d936a2SMatthias Springer llvm::interleaveComma(state.operands, os, 25136d936a2SMatthias Springer [&](Value val) { os << val.getType(); }); 25236d936a2SMatthias Springer os << ") -> ( ??? )"; 25336d936a2SMatthias Springer emitRemark(state.location, "location of op"); 25436d936a2SMatthias Springer llvm::report_fatal_error(llvm::StringRef(buffer)); 25536d936a2SMatthias Springer } 256