//===- TosaInferShapes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAINFERSHAPES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

// Check whether this use can accept a refined type. We define a use as
// refinable if its owner is a TosaOp or an op that implements a
// type-inference related interface.
// When a non-refinable use is encountered, the value is wrapped in a cast
// back to the original type after inference.
bool canBeRefined(Operation *user) {
  if (!user->getDialect())
    return false;
  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
         isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}

// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
// rolled back.
//
// This class encapsulates the state information needed to perform the
// rollback or to commit the final changes.
class TypeModificationState {
public:
  TypeModificationState() = default;

  ~TypeModificationState() {
    // Ensure the recorded modifications are either committed or rolled back.
    assert(oldTypes.empty() && "unhandled type modifications");
  }

  // Update the type of the value and record its old type.
  void setType(Value value, Type type) {
    if (value.getType() != type) {
      oldTypes.emplace_back(value, value.getType());
      value.setType(type);
    }
  }

  // Roll back changes made to the types in the IR by setting all the affected
  // values back to their old types.
  void rollBack() {
    for (auto [value, type] : oldTypes)
      value.setType(type);

    oldTypes.clear();
  }

  // Commit the changes to the types in the IR.
  // This requires inserting tensor.cast operations to mediate the newly
  // inferred result types with users that do not support type inference.
  void commit() {
    // For each use whose type changed, cast the value with the new type back
    // to the old type.
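    // Illustrative example (hypothetical types): a value refined from
    // tensor<?x?xf32> to tensor<2x3xf32> whose user cannot be refined is
    // rewired through
    //   %cast = tensor.cast %value : tensor<2x3xf32> to tensor<?x?xf32>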
    for (auto [value, oldType] : oldTypes) {
      // The call to 'use->set()' in the body of the loop below invalidates the
      // iterator used to traverse op uses, so it is important to make a copy
      // of these first.
      llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
          value.getUses(),
          [](OpOperand &use) -> OpOperand * { return &use; });

      // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
      // cached and reused by all consumers.
      tensor::CastOp castValue;

      // Traverse all uses.
      for (OpOperand *use : uses) {
        if (canBeRefined(use->getOwner()))
          continue;

        if (!castValue) {
          // Set the insertion point as far back as possible, since new
          // consumers of the 'tensor.cast' op generated in future iterations
          // are likely to be further up in the code due to the order in which
          // they appear in the use list.
          OpBuilder builder{value.getContext()};
          builder.setInsertionPointAfter(value.getDefiningOp());
          castValue =
              builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
        }

        use->set(castValue);
      }
    }

    oldTypes.clear();
  }

private:
  // A record of each value whose type was updated along with that value's
  // previous type.
  llvm::SmallVector<std::pair<Value, Type>> oldTypes;
};

void propagateShapesInRegion(Region &region, TypeModificationState &state);

void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  IfOp ifOp = dyn_cast<IfOp>(op);
  if (!ifOp)
    return;

  for (auto &region : op.getRegions()) {
    Block &frontBlock = region.front();
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
      auto blockArg = frontBlock.getArgument(i - 1);
      auto oldType = cast<ShapedType>(blockArg.getType());

      if (inferredTy.hasRank()) {
        Type newType = oldType.clone(inferredTy.getShape());
        state.setType(blockArg, newType);
      }
    }

    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
          ifOp.getOperand(i + 1).getType());
      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
          frontBlock.getArgument(i).getType());
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(operandKnowledge, blockKnowledge);
      if (!joinedKnowledge)
        continue;
      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
    }

    propagateShapesInRegion(region, state);
  }
}

void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected argument types must be compatible with every iteration of
  // the loop body / condition for tosa.while.
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  bool hasNewTypes = true;
  while (hasNewTypes) {
    TypeModificationState localState;

    // Set types on the block args.
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++) {
      localState.setType(block.getArgument(i), argTypes[i]);
    }

    // Propagate to the end of the body.
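    // Illustrative fixed-point behavior (hypothetical shapes): if the body
    // yields tensor<2xf32> for a block argument currently typed
    // tensor<1xf32>, the meet below generalizes the argument type to
    // tensor<?xf32>; a second pass over the body then produces no new types
    // and the loop terminates.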
    propagateShapesInRegion(bodyRegion, localState);

    // Find all tosa.yield ops in the body region and verify there is exactly
    // one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);

    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
    // Using the new tosa.yield operand types, infer the new subtypes.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes) {
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
    }

    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // This should never happen.
    if (yieldTypeInfo.size() != argTypes.size()) {
      op.emitWarning("has a tosa.yield with the incorrect number of operands");
      return;
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // Roll back all changes made during the speculative part of the
    // algorithm.
    localState.rollBack();
  }

  // We now set the block arguments according to the most recent shape
  // inference results. These are the final argument types for both the
  // cond and body regions.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
      state.setType(region.front().getArgument(i), argTypes[i]);
    }

    propagateShapesInRegion(region, state);
  }
}

void propagateShapesInRegion(Region &region, TypeModificationState &state) {
  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();

  for (auto &block : region) {
    for (Operation &op : block) {
      if (op.getDialect() != tosaDialect)
        continue;

      propagateShapesToTosaIf(op, state);
      propagateShapesToTosaWhile(op, state);

      InferShapedTypeOpInterface shapeInterface =
          dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        continue;

      SmallVector<ShapedTypeComponents> returnedShapes;

      if (shapeInterface
              .inferReturnTypeComponents(
                  op.getContext(), op.getLoc(), op.getOperands(),
                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
                  op.getRegions(), returnedShapes)
              .succeeded()) {
        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
          Value result = std::get<0>(it);
          ShapedTypeComponents predictedShape = std::get<1>(it);

          // Determine the knowledge based on the output type.
          // TODO: should also query WIP type probably
          Type resultTy = result.getType();
          auto currentKnowledge =
              ValueKnowledge::getKnowledgeFromType(resultTy);

          // Compute the knowledge based on the inferred type.
          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
          inferredKnowledge.dtype =
              cast<ShapedType>(resultTy).getElementType();
          inferredKnowledge.hasRank = predictedShape.hasRank();
          if (predictedShape.hasRank()) {
            for (auto dim : predictedShape.getDims()) {
              inferredKnowledge.sizes.push_back(dim);
            }
          }

          // Compute the new type based on the joined version.
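          // For example (hypothetical shapes): joining a current type of
          // tensor<?x3xf32> with an inferred shape of [2, ?] gives the more
          // specific tensor<2x3xf32>; conflicting static dimensions make the
          // join fail, in which case the result type is left unchanged.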
          auto newKnowledge =
              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
          if (!newKnowledge)
            continue;

          // Set the new type.
          state.setType(result, newKnowledge.getType());
        }
      }
    }
  }
}

/// Recursively validate TOSA ops with the SameOperandsAndResultRank trait in
/// the given region and all nested regions.
void validateSameOperandsAndResultRankTrait(Region &region) {
  int errs = 0;
  for (auto &block : region) {
    for (auto &op : block) {
      if (!op.getDialect() ||
          op.getDialect()->getNamespace() !=
              TosaDialect::getDialectNamespace())
        continue;
      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
          // The verifier has already emitted a diagnostic; the counter only
          // records that a failure occurred.
          errs++;
          (void)errs;
        }
      }
      WhileOp whileOp = dyn_cast<WhileOp>(op);
      IfOp ifOp = dyn_cast<IfOp>(op);
      if (whileOp || ifOp) {
        // Recurse into the regions of while/if ops.
        for (auto &next : op.getRegions()) {
          validateSameOperandsAndResultRankTrait(next);
        }
      }
    }
  }
}

/// Pass that performs shape propagation across TOSA operations. This includes
/// propagating shapes into the regions of if/while operations.
struct TosaInferShapes
    : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
public:
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    TypeModificationState state;
    propagateShapesInRegion(func.getBody(), state);
    state.commit();

    validateSameOperandsAndResultRankTrait(func.getBody());
  }
};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
  return std::make_unique<TosaInferShapes>();
}
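// Illustrative effect of the pass (hypothetical IR; the pass is registered as
// 'tosa-infer-shapes' and can be run via 'mlir-opt --tosa-infer-shapes'):
//   %0 = tosa.abs %arg0 : (tensor<2x3xf32>) -> tensor<?x?xf32>
// is refined to
//   %0 = tosa.abs %arg0 : (tensor<2x3xf32>) -> tensor<2x3xf32>
// while any consumer that cannot be refined keeps its original operand type
// through an inserted tensor.cast.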