//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir

LogicalResult
mlir::reifyResultShapes(OpBuilder &b, Operation *op,
                        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
  if (!reifiableOp)
    return failure();
  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
#ifndef NDEBUG
  if (failed(status))
    return failure();
  // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
  // a correct result.
  int64_t resultIdx = 0;
  for (OpResult result : op->getResults()) {
    auto shapedType = dyn_cast<ShapedType>(result.getType());
    if (!shapedType)
      continue;
    if (!shapedType.hasRank()) {
      // Nothing to check for unranked shaped values.
      ++resultIdx;
      continue;
    }
    // Assert one OpFoldResult per dimension.
    assert(shapedType.getRank() ==
               static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
           "incorrect implementation of ReifyRankedShapedTypeOpInterface");
    ++resultIdx;
  }
  // Assert that every shaped value result was reified.
  assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
         "incorrect implementation of ReifyRankedShapedTypeOpInterface");
#endif // NDEBUG
  return status;
}
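
// A minimal sketch of how a caller might invoke reifyResultShapes, assuming
// an OpBuilder `builder` and an Operation `*op` are in scope. On success,
// `reifiedShapes` holds one entry per shaped result, with one OpFoldResult
// per dimension of that result:
//
//   ReifiedRankedShapedTypeDims reifiedShapes;
//   if (failed(reifyResultShapes(builder, op, reifiedShapes)))
//     return failure();

// A ShapeAdaptor wraps exactly one of three shape sources: a ShapedType, a
// DenseIntElementsAttr holding a 1-D shape, or a ShapedTypeComponents
// pointer. Each accessor below dispatches on whichever of the three `val`
// currently holds.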

bool ShapeAdaptor::hasRank() const {
  if (val.isNull())
    return false;
  if (auto t = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(t).hasRank();
  if (isa<Attribute>(val))
    return true;
  return cast<ShapedTypeComponents *>(val)->hasRank();
}

Type ShapeAdaptor::getElementType() const {
  if (val.isNull())
    return nullptr;
  if (auto t = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(t).getElementType();
  if (isa<Attribute>(val))
    return nullptr;
  return cast<ShapedTypeComponents *>(val)->getElementType();
}

void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
  assert(hasRank());
  if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
    ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
    res.assign(vals.begin(), vals.end());
  } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    auto dattr = cast<DenseIntElementsAttr>(attr);
    res.clear();
    res.reserve(dattr.size());
    for (auto it : dattr.getValues<APInt>())
      res.push_back(it.getSExtValue());
  } else {
    auto vals = cast<ShapedTypeComponents *>(val)->getDims();
    res.assign(vals.begin(), vals.end());
  }
}

void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
  assert(hasRank());
  res.ranked = true;
  getDims(res.dims);
}

int64_t ShapeAdaptor::getDimSize(int index) const {
  assert(hasRank());
  if (auto t = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(t).getDimSize(index);
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
    return cast<DenseIntElementsAttr>(attr)
        .getValues<APInt>()[index]
        .getSExtValue();
  auto *stc = cast<ShapedTypeComponents *>(val);
  return stc->getDims()[index];
}

int64_t ShapeAdaptor::getRank() const {
  assert(hasRank());
  if (auto t = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(t).getRank();
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
    return cast<DenseIntElementsAttr>(attr).size();
  return cast<ShapedTypeComponents *>(val)->getDims().size();
}

bool ShapeAdaptor::hasStaticShape() const {
  if (!hasRank())
    return false;

  if (auto t = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(t).hasStaticShape();
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    auto dattr = cast<DenseIntElementsAttr>(attr);
    for (auto index : dattr.getValues<APInt>())
      if (ShapedType::isDynamic(index.getSExtValue()))
        return false;
    return true;
  }
  auto *stc = cast<ShapedTypeComponents *>(val);
  return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
}

int64_t ShapeAdaptor::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");

  if (auto t = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(t).getNumElements();

  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    auto dattr = cast<DenseIntElementsAttr>(attr);
    int64_t num = 1;
    for (auto index : dattr.getValues<APInt>()) {
      num *= index.getZExtValue();
      assert(num >= 0 && "integer overflow in element count computation");
    }
    return num;
  }

  auto *stc = cast<ShapedTypeComponents *>(val);
  int64_t num = 1;
  for (int64_t dim : stc->getDims()) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

void ShapeAdaptor::dump() const {
  if (!hasRank()) {
    llvm::errs() << "<<unranked>>\n";
    return;
  }

  SmallVector<int64_t> dims;
  getDims(dims);
  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
    if (ShapedType::isDynamic(dim))
      return "?";
    return llvm::formatv("{0}", dim).str();
  });
  llvm::errs() << "rank = " << getRank() << " dims = [";
  llvm::interleave(mapped, llvm::errs(), "x");
  llvm::errs() << "]\n";
}
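
// ValueShapeRange resolves shapes through the optional callbacks first: a
// value's shape comes from the `valueToShape`/`operandShape` functions when
// they yield a result, and otherwise falls back to a constant 1-D
// DenseIntElementsAttr (getValueAsShape) or to the value's type (getShape).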

ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
  Value val = operator[](index);
  if (valueToShape)
    if (ShapeAdaptor ret = valueToShape(val))
      return ret;

  DenseIntElementsAttr attr;
  if (!matchPattern(val, m_Constant(&attr)))
    return nullptr;
  if (attr.getType().getRank() != 1)
    return nullptr;
  return attr;
}

ShapeAdaptor ValueShapeRange::getShape(Value val) const {
  if (operandShape)
    if (ShapeAdaptor ret = operandShape(val))
      return ret;
  return val.getType();
}

ShapeAdaptor ValueShapeRange::getShape(int index) const {
  if (index < 0 || static_cast<size_t>(index) >= size())
    return nullptr;
  return getShape(operator[](index));
}

LogicalResult mlir::detail::inferReturnTensorTypes(
    ArrayRef<ShapedTypeComponents> retComponents,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  for (const auto &shapeAndType : retComponents) {
    Type elementTy = shapeAndType.getElementType();
    assert(elementTy && "element type required to construct tensor");

    Attribute attr = shapeAndType.getAttribute();
    if (shapeAndType.hasRank()) {
      inferredReturnTypes.push_back(
          RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
    } else {
      assert(attr == nullptr && "attribute not supported");
      inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
    }
  }
  return success();
}

LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
  auto retTypeFn = cast<InferTypeOpInterface>(op);
  auto result = retTypeFn.refineReturnTypes(
      op->getContext(), op->getLoc(), op->getOperands(),
      op->getRawDictionaryAttrs(), op->getPropertiesStorage(),
      op->getRegions(), inferredReturnTypes);
  if (failed(result))
    op->emitOpError() << "failed to infer returned types";

  return result;
}

void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) {
  std::string buffer;
  llvm::raw_string_ostream os(buffer);
  os << "Failed to infer result type(s):\n";
  os << "\"" << state.name << "\"(...) ";
  os << state.attributes.getDictionary(state.location.getContext());
  os << " : (";
  llvm::interleaveComma(state.operands, os,
                        [&](Value val) { os << val.getType(); });
  os << ") -> ( ??? )";
  emitRemark(state.location, "location of op");
  llvm::report_fatal_error(llvm::StringRef(buffer));
}