13bcaf2ebSGeorgios Pinitas //===- TosaInferShapes.cpp ------------------------------------------------===// 28dea784bSRob Suderman // 38dea784bSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48dea784bSRob Suderman // See https://llvm.org/LICENSE.txt for license information. 58dea784bSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68dea784bSRob Suderman // 78dea784bSRob Suderman //===----------------------------------------------------------------------===// 88dea784bSRob Suderman // 9f8559751SJay Foad // Propagate shapes forward along TOSA operations to resolve dynamic shape 108dea784bSRob Suderman // operations. 118dea784bSRob Suderman // 128dea784bSRob Suderman //===----------------------------------------------------------------------===// 138dea784bSRob Suderman 1467d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h" 1567d0d7acSMichele Scuttari 1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 178dea784bSRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h" 188dea784bSRob Suderman #include "mlir/Dialect/Tosa/IR/TosaOps.h" 191b00b94fSRob Suderman #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 208dea784bSRob Suderman #include "mlir/IR/Builders.h" 21e513f2c6SSpenser Bauman #include "mlir/IR/ImplicitLocOpBuilder.h" 2298aad408SSpenser Bauman #include "mlir/Interfaces/InferTypeOpInterface.h" 238dea784bSRob Suderman #include "mlir/Pass/Pass.h" 248dea784bSRob Suderman #include "mlir/Transforms/DialectConversion.h" 258dea784bSRob Suderman 2667d0d7acSMichele Scuttari namespace mlir { 2767d0d7acSMichele Scuttari namespace tosa { 2867d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSAINFERSHAPES 2967d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 3067d0d7acSMichele Scuttari } // namespace tosa 3167d0d7acSMichele Scuttari } // namespace mlir 3267d0d7acSMichele Scuttari 338dea784bSRob Suderman using namespace mlir; 348dea784bSRob Suderman using namespace mlir::tosa; 358dea784bSRob Suderman 368dea784bSRob Suderman namespace { 378dea784bSRob Suderman 380a94d35bSSpenser Bauman // Check whether this use case is replaceable. We define an op as 390a94d35bSSpenser Bauman // being replaceable if it is used by a TosaOp, or an op with a 400a94d35bSSpenser Bauman // type-inference related interface. 410a94d35bSSpenser Bauman // When a non-replaceable use is encountered, the value is wrapped in a 420a94d35bSSpenser Bauman // cast back to the original type after inference. 43e513f2c6SSpenser Bauman bool canBeRefined(Operation *user) { 440a94d35bSSpenser Bauman if (!user->getDialect()) 450a94d35bSSpenser Bauman return false; 46e513f2c6SSpenser Bauman return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() || 470a94d35bSSpenser Bauman isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user); 480a94d35bSSpenser Bauman } 490a94d35bSSpenser Bauman 500a94d35bSSpenser Bauman // During type propagation, the types of values in the operator graph are 510a94d35bSSpenser Bauman // updated. For the tosa.while_loop operation, types are speculatively updated 520a94d35bSSpenser Bauman // within the body region to determine the output type of the while_loop. This 530a94d35bSSpenser Bauman // process is performed until a fixed point is reached, then the types are 54e513f2c6SSpenser Bauman // rolled back. 550a94d35bSSpenser Bauman // 56e513f2c6SSpenser Bauman // This class encapsulates the state information needed to perform the roll back 570a94d35bSSpenser Bauman // process or to commit to the final changes. 580a94d35bSSpenser Bauman class TypeModificationState { 590a94d35bSSpenser Bauman public: 600a94d35bSSpenser Bauman TypeModificationState() = default; 610a94d35bSSpenser Bauman 620a94d35bSSpenser Bauman ~TypeModificationState() { 63e513f2c6SSpenser Bauman // Ensure the recorded modifications are either committed or rolled back. 640a94d35bSSpenser Bauman assert(oldTypes.empty() && "unhandled type modifications"); 650a94d35bSSpenser Bauman } 660a94d35bSSpenser Bauman 670a94d35bSSpenser Bauman // Update the state of the value and record the old type. 680a94d35bSSpenser Bauman void setType(Value value, Type type) { 690a94d35bSSpenser Bauman if (value.getType() != type) { 700a94d35bSSpenser Bauman oldTypes.emplace_back(value, value.getType()); 710a94d35bSSpenser Bauman value.setType(type); 720a94d35bSSpenser Bauman } 730a94d35bSSpenser Bauman } 740a94d35bSSpenser Bauman 75e513f2c6SSpenser Bauman // Roll back changes made to the types in the IR by setting all the affected 760a94d35bSSpenser Bauman // values to their old types. 77e513f2c6SSpenser Bauman void rollBack() { 780a94d35bSSpenser Bauman for (auto [value, type] : oldTypes) 790a94d35bSSpenser Bauman value.setType(type); 800a94d35bSSpenser Bauman 810a94d35bSSpenser Bauman oldTypes.clear(); 820a94d35bSSpenser Bauman } 830a94d35bSSpenser Bauman 840a94d35bSSpenser Bauman // Commit the changes to the types in the IR. 850a94d35bSSpenser Bauman // This requires inserting tensor.cast operations to mediate the newly 860a94d35bSSpenser Bauman // inferred result types with users that do not support type inference. 870a94d35bSSpenser Bauman void commit() { 880a94d35bSSpenser Bauman // For each use whose type changed, cast the value with the new type back to 890a94d35bSSpenser Bauman // the old type. 900a94d35bSSpenser Bauman for (auto [value, oldType] : oldTypes) { 91ef6e7affSRafael Ubal // The call to 'use->set()' in the body of the loop below invalidates the 92ef6e7affSRafael Ubal // iterator used to traverse op uses, so it is important to make a copy of 93ef6e7affSRafael Ubal // these first. 94ef6e7affSRafael Ubal llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector( 95ef6e7affSRafael Ubal value.getUses(), 96ef6e7affSRafael Ubal [](OpOperand &use) -> OpOperand * { 97ef6e7affSRafael Ubal return &use; 98ef6e7affSRafael Ubal }); 99ef6e7affSRafael Ubal 100ef6e7affSRafael Ubal // A 'tensor.cast' op is emitted only if needed. Once emitted, it is 101ef6e7affSRafael Ubal // cached and reused by all consumers. 102ef6e7affSRafael Ubal tensor::CastOp castValue; 103ef6e7affSRafael Ubal 104ef6e7affSRafael Ubal // Traverse all uses 105ef6e7affSRafael Ubal for (OpOperand *use : uses) { 106ef6e7affSRafael Ubal if (canBeRefined(use->getOwner())) 1070a94d35bSSpenser Bauman continue; 1080a94d35bSSpenser Bauman 109ef6e7affSRafael Ubal if (!castValue) { 110ef6e7affSRafael Ubal // Set the insertion point as far back as possible, since new 111ef6e7affSRafael Ubal // consumers of the 'tensor.cast' op generated in future iterations 112ef6e7affSRafael Ubal // are likely to be further up in the code due to the order in which 113ef6e7affSRafael Ubal // they appear in the use list. 114ef6e7affSRafael Ubal OpBuilder builder{value.getContext()}; 115ef6e7affSRafael Ubal builder.setInsertionPointAfter(value.getDefiningOp()); 116ef6e7affSRafael Ubal castValue = 117ef6e7affSRafael Ubal builder.create<tensor::CastOp>(value.getLoc(), oldType, value); 118e513f2c6SSpenser Bauman } 1190a94d35bSSpenser Bauman 120ef6e7affSRafael Ubal use->set(castValue); 1210a94d35bSSpenser Bauman } 1220a94d35bSSpenser Bauman } 1230a94d35bSSpenser Bauman 1240a94d35bSSpenser Bauman oldTypes.clear(); 1250a94d35bSSpenser Bauman } 1260a94d35bSSpenser Bauman 1270a94d35bSSpenser Bauman private: 1280a94d35bSSpenser Bauman // A record of each value whose type was updated along with that value's 1290a94d35bSSpenser Bauman // previous type. 1300a94d35bSSpenser Bauman llvm::SmallVector<std::pair<Value, Type>> oldTypes; 1310a94d35bSSpenser Bauman }; 1320a94d35bSSpenser Bauman 1330a94d35bSSpenser Bauman void propagateShapesInRegion(Region ®ion, TypeModificationState &state); 1340a94d35bSSpenser Bauman 1350a94d35bSSpenser Bauman void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) { 136b0532286SRob Suderman IfOp ifOp = dyn_cast<IfOp>(op); 1371b00b94fSRob Suderman if (!ifOp) 1388dea784bSRob Suderman return; 1391b00b94fSRob Suderman 1401b00b94fSRob Suderman for (auto ®ion : op.getRegions()) { 1411b00b94fSRob Suderman Block &frontBlock = region.front(); 1421b00b94fSRob Suderman if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands()) 1431b00b94fSRob Suderman return; 1441b00b94fSRob Suderman 145b0532286SRob Suderman for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { 146e9cb5828Smaxbartel auto inferredTy = cast<ShapedType>(op.getOperand(i).getType()); 147b0532286SRob Suderman auto blockArg = frontBlock.getArgument(i - 1); 1485550c821STres Popp auto oldType = cast<ShapedType>(blockArg.getType()); 149b0532286SRob Suderman 150b0532286SRob Suderman if (inferredTy.hasRank()) { 151e9cb5828Smaxbartel Type newType = oldType.clone(inferredTy.getShape()); 1520a94d35bSSpenser Bauman state.setType(blockArg, newType); 153b0532286SRob Suderman } 154b0532286SRob Suderman } 155b0532286SRob Suderman 1561b00b94fSRob Suderman for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) { 1571b00b94fSRob Suderman ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType( 1581b00b94fSRob Suderman ifOp.getOperand(i + 1).getType()); 1591b00b94fSRob Suderman ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType( 1601b00b94fSRob Suderman frontBlock.getArgument(i).getType()); 1611b00b94fSRob Suderman ValueKnowledge joinedKnowledge = 1621b00b94fSRob Suderman ValueKnowledge::join(operandKnowledge, blockKnowledge); 1631b00b94fSRob Suderman if (!joinedKnowledge) 1641b00b94fSRob Suderman continue; 1650a94d35bSSpenser Bauman state.setType(frontBlock.getArgument(i), joinedKnowledge.getType()); 1661b00b94fSRob Suderman } 1671b00b94fSRob Suderman 1680a94d35bSSpenser Bauman propagateShapesInRegion(region, state); 1691b00b94fSRob Suderman } 170b0532286SRob Suderman } 1711b00b94fSRob Suderman 1720a94d35bSSpenser Bauman void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) { 173b0532286SRob Suderman WhileOp whileOp = dyn_cast<WhileOp>(op); 174b0532286SRob Suderman if (!whileOp) 1751b00b94fSRob Suderman return; 176b0532286SRob Suderman 177b0532286SRob Suderman // Determine what the expected argument types are to the cond/body blocks. 178b0532286SRob Suderman // The expected arguments should be compatible with ever iteration of the 179b0532286SRob Suderman // loop body / condition for tosa.while. 1800a94d35bSSpenser Bauman SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes()); 181b0532286SRob Suderman 182b0532286SRob Suderman bool hasNewTypes = true; 183b0532286SRob Suderman while (hasNewTypes) { 1840a94d35bSSpenser Bauman TypeModificationState localState; 185b0532286SRob Suderman 186b0532286SRob Suderman // Set types on the block args. 187b0532286SRob Suderman Region &bodyRegion = op.getRegion(1); 188b0532286SRob Suderman Block &block = bodyRegion.front(); 189b0532286SRob Suderman for (int i = 0, s = argTypes.size(); i < s; i++) { 1900a94d35bSSpenser Bauman localState.setType(block.getArgument(i), argTypes[i]); 191b0532286SRob Suderman } 192b0532286SRob Suderman 193b0532286SRob Suderman // Propagate to the end. 1940a94d35bSSpenser Bauman propagateShapesInRegion(bodyRegion, localState); 195b0532286SRob Suderman 1960a94d35bSSpenser Bauman // Find all the tosa yield types and verify there is a single one. 197b0532286SRob Suderman llvm::SmallVector<YieldOp> yieldOps; 198b0532286SRob Suderman for (auto &block : bodyRegion) 199b0532286SRob Suderman if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) 200b0532286SRob Suderman yieldOps.push_back(yieldOp); 201b0532286SRob Suderman 2020a94d35bSSpenser Bauman assert(yieldOps.size() == 1 && "missing or non-unique yield op"); 203b0532286SRob Suderman // Using the new tosa.yield operand types, infer the new subtypes. 204b0532286SRob Suderman llvm::SmallVector<ValueKnowledge> yieldTypeInfo; 205b0532286SRob Suderman for (auto ty : argTypes) { 206b0532286SRob Suderman yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty)); 207b0532286SRob Suderman } 208b0532286SRob Suderman 209b0532286SRob Suderman for (auto yieldOp : yieldOps) { 21089de9cc8SMehdi Amini for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 211b0532286SRob Suderman auto newKnowledge = 212b0532286SRob Suderman ValueKnowledge::getKnowledgeFromType(it.value().getType()); 213b0532286SRob Suderman yieldTypeInfo[it.index()] = 214b0532286SRob Suderman ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge); 215b0532286SRob Suderman } 216b0532286SRob Suderman } 217b0532286SRob Suderman 218b0532286SRob Suderman // This should never happen. 219b0532286SRob Suderman if (yieldTypeInfo.size() != argTypes.size()) { 220b0532286SRob Suderman op.emitWarning("has a tosa.yield with the incorrect number of operands"); 221b0532286SRob Suderman return; 222b0532286SRob Suderman } 223b0532286SRob Suderman 224b0532286SRob Suderman // Determine the new block args and see if any changed. 225b0532286SRob Suderman hasNewTypes = false; 226b0532286SRob Suderman for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) { 227b0532286SRob Suderman Type newType = yieldTypeInfo[i].getType(); 228b0532286SRob Suderman hasNewTypes |= (newType != argTypes[i]); 229b0532286SRob Suderman argTypes[i] = newType; 230b0532286SRob Suderman } 231b0532286SRob Suderman 232e513f2c6SSpenser Bauman // Roll back all changes made during the speculative part of the algorithm. 233e513f2c6SSpenser Bauman localState.rollBack(); 234b0532286SRob Suderman } 235b0532286SRob Suderman 236b0532286SRob Suderman // We now set the block arguments according to the most recent shape 237b0532286SRob Suderman // inference results. This gives us the block arg types for the next 238b0532286SRob Suderman // iteration. 239b0532286SRob Suderman for (auto ®ion : op.getRegions()) { 240b0532286SRob Suderman for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { 2410a94d35bSSpenser Bauman state.setType(region.front().getArgument(i), argTypes[i]); 242b0532286SRob Suderman } 243b0532286SRob Suderman 2440a94d35bSSpenser Bauman propagateShapesInRegion(region, state); 245b0532286SRob Suderman } 2461b00b94fSRob Suderman } 2471b00b94fSRob Suderman 2480a94d35bSSpenser Bauman void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { 249e513f2c6SSpenser Bauman Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>(); 250e513f2c6SSpenser Bauman 2511b00b94fSRob Suderman for (auto &block : region) { 2521b00b94fSRob Suderman for (Operation &op : block) { 253e513f2c6SSpenser Bauman if (op.getDialect() != tosaDialect) 2541b00b94fSRob Suderman continue; 2551b00b94fSRob Suderman 2560a94d35bSSpenser Bauman propagateShapesToTosaIf(op, state); 2570a94d35bSSpenser Bauman propagateShapesToTosaWhile(op, state); 2581b00b94fSRob Suderman 2598dea784bSRob Suderman InferShapedTypeOpInterface shapeInterface = 2608dea784bSRob Suderman dyn_cast<InferShapedTypeOpInterface>(op); 2618dea784bSRob Suderman if (!shapeInterface) 2621b00b94fSRob Suderman continue; 2638dea784bSRob Suderman 2648dea784bSRob Suderman SmallVector<ShapedTypeComponents> returnedShapes; 26509349303SJacques Pienaar 2668dea784bSRob Suderman if (shapeInterface 267e9cb5828Smaxbartel .inferReturnTypeComponents( 268e9cb5828Smaxbartel op.getContext(), op.getLoc(), op.getOperands(), 269e9cb5828Smaxbartel op.getDiscardableAttrDictionary(), op.getPropertiesStorage(), 270bbe5bf17SMehdi Amini op.getRegions(), returnedShapes) 2718dea784bSRob Suderman .succeeded()) { 2721b00b94fSRob Suderman for (auto it : llvm::zip(op.getResults(), returnedShapes)) { 2738dea784bSRob Suderman Value result = std::get<0>(it); 2748dea784bSRob Suderman ShapedTypeComponents predictedShape = std::get<1>(it); 2758dea784bSRob Suderman 2768dea784bSRob Suderman // Determine the knowledge based on the output type. 27709349303SJacques Pienaar // TODO: should also query WIP type probably 2788dea784bSRob Suderman Type resultTy = result.getType(); 2798dea784bSRob Suderman auto currentKnowledge = 2808dea784bSRob Suderman ValueKnowledge::getKnowledgeFromType(resultTy); 2818dea784bSRob Suderman 2828dea784bSRob Suderman // Compute the knowledge based on the inferred type. 2831b00b94fSRob Suderman auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); 2845550c821STres Popp inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType(); 2851b00b94fSRob Suderman inferredKnowledge.hasRank = predictedShape.hasRank(); 2868dea784bSRob Suderman if (predictedShape.hasRank()) { 2878dea784bSRob Suderman for (auto dim : predictedShape.getDims()) { 2888dea784bSRob Suderman inferredKnowledge.sizes.push_back(dim); 2898dea784bSRob Suderman } 2908dea784bSRob Suderman } 2918dea784bSRob Suderman 2928dea784bSRob Suderman // Compute the new type based on the joined version. 2938dea784bSRob Suderman auto newKnowledge = 2948dea784bSRob Suderman ValueKnowledge::join(currentKnowledge, inferredKnowledge); 2951b00b94fSRob Suderman if (!newKnowledge) 2961b00b94fSRob Suderman continue; 29709349303SJacques Pienaar 298e9cb5828Smaxbartel // Set new type 2990a94d35bSSpenser Bauman state.setType(result, newKnowledge.getType()); 300e9cb5828Smaxbartel } 301e9cb5828Smaxbartel } 30209349303SJacques Pienaar } 30309349303SJacques Pienaar } 3041b00b94fSRob Suderman } 3051b00b94fSRob Suderman 306729f958cSTai Ly /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region 307729f958cSTai Ly /// and all nested regions 308729f958cSTai Ly void validateSameOperandsAndResultRankTrait(Region ®ion) { 309729f958cSTai Ly int errs = 0; 310729f958cSTai Ly for (auto &block : region) { 311729f958cSTai Ly for (auto &op : block) { 312729f958cSTai Ly if (!op.getDialect() || 313729f958cSTai Ly op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) 314729f958cSTai Ly continue; 315729f958cSTai Ly if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) { 316729f958cSTai Ly if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) { 317729f958cSTai Ly errs++; 318*5a9b74d2SKazu Hirata (void)errs; 319729f958cSTai Ly } 320729f958cSTai Ly } 321729f958cSTai Ly WhileOp whileOp = dyn_cast<WhileOp>(op); 322729f958cSTai Ly IfOp ifOp = dyn_cast<IfOp>(op); 323729f958cSTai Ly if (whileOp || ifOp) { 324729f958cSTai Ly // recurse into whileOp's regions 325729f958cSTai Ly for (auto &next : op.getRegions()) { 326729f958cSTai Ly validateSameOperandsAndResultRankTrait(next); 327729f958cSTai Ly } 328729f958cSTai Ly } 329729f958cSTai Ly } 330729f958cSTai Ly } 331729f958cSTai Ly } 332729f958cSTai Ly 3331b00b94fSRob Suderman /// Pass that performs shape propagation across TOSA operations. This includes 3341b00b94fSRob Suderman /// migrating to within the regions of if/while operations. 33567d0d7acSMichele Scuttari struct TosaInferShapes 33667d0d7acSMichele Scuttari : public tosa::impl::TosaInferShapesBase<TosaInferShapes> { 3371b00b94fSRob Suderman public: 33841574554SRiver Riddle void runOnOperation() override { 33958ceae95SRiver Riddle func::FuncOp func = getOperation(); 3400a94d35bSSpenser Bauman TypeModificationState state; 3410a94d35bSSpenser Bauman propagateShapesInRegion(func.getBody(), state); 3420a94d35bSSpenser Bauman state.commit(); 343729f958cSTai Ly 344729f958cSTai Ly validateSameOperandsAndResultRankTrait(func.getBody()); 3458dea784bSRob Suderman } 3468dea784bSRob Suderman }; 347be0a7e9fSMehdi Amini } // namespace 348039b969bSMichele Scuttari 349039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() { 350039b969bSMichele Scuttari return std::make_unique<TosaInferShapes>(); 351039b969bSMichele Scuttari } 352