xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (revision 5a9b74d20d5f3b7f92c01d68d28778108dfb1308)
13bcaf2ebSGeorgios Pinitas //===- TosaInferShapes.cpp ------------------------------------------------===//
28dea784bSRob Suderman //
38dea784bSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48dea784bSRob Suderman // See https://llvm.org/LICENSE.txt for license information.
58dea784bSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68dea784bSRob Suderman //
78dea784bSRob Suderman //===----------------------------------------------------------------------===//
88dea784bSRob Suderman //
9f8559751SJay Foad // Propagate shapes forward along TOSA operations to resolve dynamic shape
108dea784bSRob Suderman // operations.
118dea784bSRob Suderman //
128dea784bSRob Suderman //===----------------------------------------------------------------------===//
138dea784bSRob Suderman 
1467d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h"
1567d0d7acSMichele Scuttari 
1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
178dea784bSRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h"
188dea784bSRob Suderman #include "mlir/Dialect/Tosa/IR/TosaOps.h"
191b00b94fSRob Suderman #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
208dea784bSRob Suderman #include "mlir/IR/Builders.h"
21e513f2c6SSpenser Bauman #include "mlir/IR/ImplicitLocOpBuilder.h"
2298aad408SSpenser Bauman #include "mlir/Interfaces/InferTypeOpInterface.h"
238dea784bSRob Suderman #include "mlir/Pass/Pass.h"
248dea784bSRob Suderman #include "mlir/Transforms/DialectConversion.h"
258dea784bSRob Suderman 
2667d0d7acSMichele Scuttari namespace mlir {
2767d0d7acSMichele Scuttari namespace tosa {
2867d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSAINFERSHAPES
2967d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
3067d0d7acSMichele Scuttari } // namespace tosa
3167d0d7acSMichele Scuttari } // namespace mlir
3267d0d7acSMichele Scuttari 
338dea784bSRob Suderman using namespace mlir;
348dea784bSRob Suderman using namespace mlir::tosa;
358dea784bSRob Suderman 
368dea784bSRob Suderman namespace {
378dea784bSRob Suderman 
380a94d35bSSpenser Bauman // Check whether this use case is replaceable. We define an op as
390a94d35bSSpenser Bauman // being replaceable if it is used by a TosaOp, or an op with a
400a94d35bSSpenser Bauman // type-inference related interface.
410a94d35bSSpenser Bauman // When a non-replaceable use is encountered, the value is wrapped in a
420a94d35bSSpenser Bauman // cast back to the original type after inference.
43e513f2c6SSpenser Bauman bool canBeRefined(Operation *user) {
440a94d35bSSpenser Bauman   if (!user->getDialect())
450a94d35bSSpenser Bauman     return false;
46e513f2c6SSpenser Bauman   return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
470a94d35bSSpenser Bauman          isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
480a94d35bSSpenser Bauman }
490a94d35bSSpenser Bauman 
500a94d35bSSpenser Bauman // During type propagation, the types of values in the operator graph are
510a94d35bSSpenser Bauman // updated. For the tosa.while_loop operation, types are speculatively updated
520a94d35bSSpenser Bauman // within the body region to determine the output type of the while_loop. This
530a94d35bSSpenser Bauman // process is performed until a fixed point is reached, then the types are
54e513f2c6SSpenser Bauman // rolled back.
550a94d35bSSpenser Bauman //
56e513f2c6SSpenser Bauman // This class encapsulates the state information needed to perform the roll back
570a94d35bSSpenser Bauman // process or to commit to the final changes.
580a94d35bSSpenser Bauman class TypeModificationState {
590a94d35bSSpenser Bauman public:
600a94d35bSSpenser Bauman   TypeModificationState() = default;
610a94d35bSSpenser Bauman 
620a94d35bSSpenser Bauman   ~TypeModificationState() {
63e513f2c6SSpenser Bauman     // Ensure the recorded modifications are either committed or rolled back.
640a94d35bSSpenser Bauman     assert(oldTypes.empty() && "unhandled type modifications");
650a94d35bSSpenser Bauman   }
660a94d35bSSpenser Bauman 
670a94d35bSSpenser Bauman   // Update the state of the value and record the old type.
680a94d35bSSpenser Bauman   void setType(Value value, Type type) {
690a94d35bSSpenser Bauman     if (value.getType() != type) {
700a94d35bSSpenser Bauman       oldTypes.emplace_back(value, value.getType());
710a94d35bSSpenser Bauman       value.setType(type);
720a94d35bSSpenser Bauman     }
730a94d35bSSpenser Bauman   }
740a94d35bSSpenser Bauman 
75e513f2c6SSpenser Bauman   // Roll back changes made to the types in the IR by setting all the affected
760a94d35bSSpenser Bauman   // values to their old types.
77e513f2c6SSpenser Bauman   void rollBack() {
780a94d35bSSpenser Bauman     for (auto [value, type] : oldTypes)
790a94d35bSSpenser Bauman       value.setType(type);
800a94d35bSSpenser Bauman 
810a94d35bSSpenser Bauman     oldTypes.clear();
820a94d35bSSpenser Bauman   }
830a94d35bSSpenser Bauman 
840a94d35bSSpenser Bauman   // Commit the changes to the types in the IR.
850a94d35bSSpenser Bauman   // This requires inserting tensor.cast operations to mediate the newly
860a94d35bSSpenser Bauman   // inferred result types with users that do not support type inference.
870a94d35bSSpenser Bauman   void commit() {
880a94d35bSSpenser Bauman     // For each use whose type changed, cast the value with the new type back to
890a94d35bSSpenser Bauman     // the old type.
900a94d35bSSpenser Bauman     for (auto [value, oldType] : oldTypes) {
91ef6e7affSRafael Ubal       // The call to 'use->set()' in the body of the loop below invalidates the
92ef6e7affSRafael Ubal       // iterator used to traverse op uses, so it is important to make a copy of
93ef6e7affSRafael Ubal       // these first.
94ef6e7affSRafael Ubal       llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
95ef6e7affSRafael Ubal           value.getUses(),
96ef6e7affSRafael Ubal           [](OpOperand &use) -> OpOperand * {
97ef6e7affSRafael Ubal             return &use;
98ef6e7affSRafael Ubal           });
99ef6e7affSRafael Ubal 
100ef6e7affSRafael Ubal       // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
101ef6e7affSRafael Ubal       // cached and reused by all consumers.
102ef6e7affSRafael Ubal       tensor::CastOp castValue;
103ef6e7affSRafael Ubal 
104ef6e7affSRafael Ubal       // Traverse all uses
105ef6e7affSRafael Ubal       for (OpOperand *use : uses) {
106ef6e7affSRafael Ubal         if (canBeRefined(use->getOwner()))
1070a94d35bSSpenser Bauman           continue;
1080a94d35bSSpenser Bauman 
109ef6e7affSRafael Ubal         if (!castValue) {
110ef6e7affSRafael Ubal           // Set the insertion point as far back as possible, since new
111ef6e7affSRafael Ubal           // consumers of the 'tensor.cast' op generated in future iterations
112ef6e7affSRafael Ubal           // are likely to be further up in the code due to the order in which
113ef6e7affSRafael Ubal           // they appear in the use list.
114ef6e7affSRafael Ubal           OpBuilder builder{value.getContext()};
115ef6e7affSRafael Ubal           builder.setInsertionPointAfter(value.getDefiningOp());
116ef6e7affSRafael Ubal           castValue =
117ef6e7affSRafael Ubal               builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
118e513f2c6SSpenser Bauman         }
1190a94d35bSSpenser Bauman 
120ef6e7affSRafael Ubal         use->set(castValue);
1210a94d35bSSpenser Bauman       }
1220a94d35bSSpenser Bauman     }
1230a94d35bSSpenser Bauman 
1240a94d35bSSpenser Bauman     oldTypes.clear();
1250a94d35bSSpenser Bauman   }
1260a94d35bSSpenser Bauman 
1270a94d35bSSpenser Bauman private:
1280a94d35bSSpenser Bauman   // A record of each value whose type was updated along with that value's
1290a94d35bSSpenser Bauman   // previous type.
1300a94d35bSSpenser Bauman   llvm::SmallVector<std::pair<Value, Type>> oldTypes;
1310a94d35bSSpenser Bauman };
1320a94d35bSSpenser Bauman 
1330a94d35bSSpenser Bauman void propagateShapesInRegion(Region &region, TypeModificationState &state);
1340a94d35bSSpenser Bauman 
1350a94d35bSSpenser Bauman void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
136b0532286SRob Suderman   IfOp ifOp = dyn_cast<IfOp>(op);
1371b00b94fSRob Suderman   if (!ifOp)
1388dea784bSRob Suderman     return;
1391b00b94fSRob Suderman 
1401b00b94fSRob Suderman   for (auto &region : op.getRegions()) {
1411b00b94fSRob Suderman     Block &frontBlock = region.front();
1421b00b94fSRob Suderman     if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
1431b00b94fSRob Suderman       return;
1441b00b94fSRob Suderman 
145b0532286SRob Suderman     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
146e9cb5828Smaxbartel       auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
147b0532286SRob Suderman       auto blockArg = frontBlock.getArgument(i - 1);
1485550c821STres Popp       auto oldType = cast<ShapedType>(blockArg.getType());
149b0532286SRob Suderman 
150b0532286SRob Suderman       if (inferredTy.hasRank()) {
151e9cb5828Smaxbartel         Type newType = oldType.clone(inferredTy.getShape());
1520a94d35bSSpenser Bauman         state.setType(blockArg, newType);
153b0532286SRob Suderman       }
154b0532286SRob Suderman     }
155b0532286SRob Suderman 
1561b00b94fSRob Suderman     for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
1571b00b94fSRob Suderman       ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
1581b00b94fSRob Suderman           ifOp.getOperand(i + 1).getType());
1591b00b94fSRob Suderman       ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
1601b00b94fSRob Suderman           frontBlock.getArgument(i).getType());
1611b00b94fSRob Suderman       ValueKnowledge joinedKnowledge =
1621b00b94fSRob Suderman           ValueKnowledge::join(operandKnowledge, blockKnowledge);
1631b00b94fSRob Suderman       if (!joinedKnowledge)
1641b00b94fSRob Suderman         continue;
1650a94d35bSSpenser Bauman       state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
1661b00b94fSRob Suderman     }
1671b00b94fSRob Suderman 
1680a94d35bSSpenser Bauman     propagateShapesInRegion(region, state);
1691b00b94fSRob Suderman   }
170b0532286SRob Suderman }
1711b00b94fSRob Suderman 
1720a94d35bSSpenser Bauman void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
173b0532286SRob Suderman   WhileOp whileOp = dyn_cast<WhileOp>(op);
174b0532286SRob Suderman   if (!whileOp)
1751b00b94fSRob Suderman     return;
176b0532286SRob Suderman 
177b0532286SRob Suderman   // Determine what the expected argument types are to the cond/body blocks.
178b0532286SRob Suderman   // The expected arguments should be compatible with ever iteration of the
179b0532286SRob Suderman   // loop body / condition for tosa.while.
1800a94d35bSSpenser Bauman   SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
181b0532286SRob Suderman 
182b0532286SRob Suderman   bool hasNewTypes = true;
183b0532286SRob Suderman   while (hasNewTypes) {
1840a94d35bSSpenser Bauman     TypeModificationState localState;
185b0532286SRob Suderman 
186b0532286SRob Suderman     // Set types on the block args.
187b0532286SRob Suderman     Region &bodyRegion = op.getRegion(1);
188b0532286SRob Suderman     Block &block = bodyRegion.front();
189b0532286SRob Suderman     for (int i = 0, s = argTypes.size(); i < s; i++) {
1900a94d35bSSpenser Bauman       localState.setType(block.getArgument(i), argTypes[i]);
191b0532286SRob Suderman     }
192b0532286SRob Suderman 
193b0532286SRob Suderman     // Propagate to the end.
1940a94d35bSSpenser Bauman     propagateShapesInRegion(bodyRegion, localState);
195b0532286SRob Suderman 
1960a94d35bSSpenser Bauman     // Find all the tosa yield types and verify there is a single one.
197b0532286SRob Suderman     llvm::SmallVector<YieldOp> yieldOps;
198b0532286SRob Suderman     for (auto &block : bodyRegion)
199b0532286SRob Suderman       if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
200b0532286SRob Suderman         yieldOps.push_back(yieldOp);
201b0532286SRob Suderman 
2020a94d35bSSpenser Bauman     assert(yieldOps.size() == 1 && "missing or non-unique yield op");
203b0532286SRob Suderman     // Using the new tosa.yield operand types, infer the new subtypes.
204b0532286SRob Suderman     llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
205b0532286SRob Suderman     for (auto ty : argTypes) {
206b0532286SRob Suderman       yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
207b0532286SRob Suderman     }
208b0532286SRob Suderman 
209b0532286SRob Suderman     for (auto yieldOp : yieldOps) {
21089de9cc8SMehdi Amini       for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
211b0532286SRob Suderman         auto newKnowledge =
212b0532286SRob Suderman             ValueKnowledge::getKnowledgeFromType(it.value().getType());
213b0532286SRob Suderman         yieldTypeInfo[it.index()] =
214b0532286SRob Suderman             ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
215b0532286SRob Suderman       }
216b0532286SRob Suderman     }
217b0532286SRob Suderman 
218b0532286SRob Suderman     // This should never happen.
219b0532286SRob Suderman     if (yieldTypeInfo.size() != argTypes.size()) {
220b0532286SRob Suderman       op.emitWarning("has a tosa.yield with the incorrect number of operands");
221b0532286SRob Suderman       return;
222b0532286SRob Suderman     }
223b0532286SRob Suderman 
224b0532286SRob Suderman     // Determine the new block args and see if any changed.
225b0532286SRob Suderman     hasNewTypes = false;
226b0532286SRob Suderman     for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
227b0532286SRob Suderman       Type newType = yieldTypeInfo[i].getType();
228b0532286SRob Suderman       hasNewTypes |= (newType != argTypes[i]);
229b0532286SRob Suderman       argTypes[i] = newType;
230b0532286SRob Suderman     }
231b0532286SRob Suderman 
232e513f2c6SSpenser Bauman     // Roll back all changes made during the speculative part of the algorithm.
233e513f2c6SSpenser Bauman     localState.rollBack();
234b0532286SRob Suderman   }
235b0532286SRob Suderman 
236b0532286SRob Suderman   // We now set the block arguments according to the most recent shape
237b0532286SRob Suderman   // inference results. This gives us the block arg types for the next
238b0532286SRob Suderman   // iteration.
239b0532286SRob Suderman   for (auto &region : op.getRegions()) {
240b0532286SRob Suderman     for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
2410a94d35bSSpenser Bauman       state.setType(region.front().getArgument(i), argTypes[i]);
242b0532286SRob Suderman     }
243b0532286SRob Suderman 
2440a94d35bSSpenser Bauman     propagateShapesInRegion(region, state);
245b0532286SRob Suderman   }
2461b00b94fSRob Suderman }
2471b00b94fSRob Suderman 
2480a94d35bSSpenser Bauman void propagateShapesInRegion(Region &region, TypeModificationState &state) {
249e513f2c6SSpenser Bauman   Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
250e513f2c6SSpenser Bauman 
2511b00b94fSRob Suderman   for (auto &block : region) {
2521b00b94fSRob Suderman     for (Operation &op : block) {
253e513f2c6SSpenser Bauman       if (op.getDialect() != tosaDialect)
2541b00b94fSRob Suderman         continue;
2551b00b94fSRob Suderman 
2560a94d35bSSpenser Bauman       propagateShapesToTosaIf(op, state);
2570a94d35bSSpenser Bauman       propagateShapesToTosaWhile(op, state);
2581b00b94fSRob Suderman 
2598dea784bSRob Suderman       InferShapedTypeOpInterface shapeInterface =
2608dea784bSRob Suderman           dyn_cast<InferShapedTypeOpInterface>(op);
2618dea784bSRob Suderman       if (!shapeInterface)
2621b00b94fSRob Suderman         continue;
2638dea784bSRob Suderman 
2648dea784bSRob Suderman       SmallVector<ShapedTypeComponents> returnedShapes;
26509349303SJacques Pienaar 
2668dea784bSRob Suderman       if (shapeInterface
267e9cb5828Smaxbartel               .inferReturnTypeComponents(
268e9cb5828Smaxbartel                   op.getContext(), op.getLoc(), op.getOperands(),
269e9cb5828Smaxbartel                   op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
270bbe5bf17SMehdi Amini                   op.getRegions(), returnedShapes)
2718dea784bSRob Suderman               .succeeded()) {
2721b00b94fSRob Suderman         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
2738dea784bSRob Suderman           Value result = std::get<0>(it);
2748dea784bSRob Suderman           ShapedTypeComponents predictedShape = std::get<1>(it);
2758dea784bSRob Suderman 
2768dea784bSRob Suderman           // Determine the knowledge based on the output type.
27709349303SJacques Pienaar           // TODO: should also query WIP type probably
2788dea784bSRob Suderman           Type resultTy = result.getType();
2798dea784bSRob Suderman           auto currentKnowledge =
2808dea784bSRob Suderman               ValueKnowledge::getKnowledgeFromType(resultTy);
2818dea784bSRob Suderman 
2828dea784bSRob Suderman           // Compute the knowledge based on the inferred type.
2831b00b94fSRob Suderman           auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
2845550c821STres Popp           inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
2851b00b94fSRob Suderman           inferredKnowledge.hasRank = predictedShape.hasRank();
2868dea784bSRob Suderman           if (predictedShape.hasRank()) {
2878dea784bSRob Suderman             for (auto dim : predictedShape.getDims()) {
2888dea784bSRob Suderman               inferredKnowledge.sizes.push_back(dim);
2898dea784bSRob Suderman             }
2908dea784bSRob Suderman           }
2918dea784bSRob Suderman 
2928dea784bSRob Suderman           // Compute the new type based on the joined version.
2938dea784bSRob Suderman           auto newKnowledge =
2948dea784bSRob Suderman               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
2951b00b94fSRob Suderman           if (!newKnowledge)
2961b00b94fSRob Suderman             continue;
29709349303SJacques Pienaar 
298e9cb5828Smaxbartel           // Set new type
2990a94d35bSSpenser Bauman           state.setType(result, newKnowledge.getType());
300e9cb5828Smaxbartel         }
301e9cb5828Smaxbartel       }
30209349303SJacques Pienaar     }
30309349303SJacques Pienaar   }
3041b00b94fSRob Suderman }
3051b00b94fSRob Suderman 
306729f958cSTai Ly /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
307729f958cSTai Ly /// and all nested regions
308729f958cSTai Ly void validateSameOperandsAndResultRankTrait(Region &region) {
309729f958cSTai Ly   int errs = 0;
310729f958cSTai Ly   for (auto &block : region) {
311729f958cSTai Ly     for (auto &op : block) {
312729f958cSTai Ly       if (!op.getDialect() ||
313729f958cSTai Ly           op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
314729f958cSTai Ly         continue;
315729f958cSTai Ly       if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
316729f958cSTai Ly         if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
317729f958cSTai Ly           errs++;
318*5a9b74d2SKazu Hirata           (void)errs;
319729f958cSTai Ly         }
320729f958cSTai Ly       }
321729f958cSTai Ly       WhileOp whileOp = dyn_cast<WhileOp>(op);
322729f958cSTai Ly       IfOp ifOp = dyn_cast<IfOp>(op);
323729f958cSTai Ly       if (whileOp || ifOp) {
324729f958cSTai Ly         // recurse into whileOp's regions
325729f958cSTai Ly         for (auto &next : op.getRegions()) {
326729f958cSTai Ly           validateSameOperandsAndResultRankTrait(next);
327729f958cSTai Ly         }
328729f958cSTai Ly       }
329729f958cSTai Ly     }
330729f958cSTai Ly   }
331729f958cSTai Ly }
332729f958cSTai Ly 
3331b00b94fSRob Suderman /// Pass that performs shape propagation across TOSA operations. This includes
3341b00b94fSRob Suderman /// migrating to within the regions of if/while operations.
33567d0d7acSMichele Scuttari struct TosaInferShapes
33667d0d7acSMichele Scuttari     : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
3371b00b94fSRob Suderman public:
33841574554SRiver Riddle   void runOnOperation() override {
33958ceae95SRiver Riddle     func::FuncOp func = getOperation();
3400a94d35bSSpenser Bauman     TypeModificationState state;
3410a94d35bSSpenser Bauman     propagateShapesInRegion(func.getBody(), state);
3420a94d35bSSpenser Bauman     state.commit();
343729f958cSTai Ly 
344729f958cSTai Ly     validateSameOperandsAndResultRankTrait(func.getBody());
3458dea784bSRob Suderman   }
3468dea784bSRob Suderman };
347be0a7e9fSMehdi Amini } // namespace
348039b969bSMichele Scuttari 
349039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
350039b969bSMichele Scuttari   return std::make_unique<TosaInferShapes>();
351039b969bSMichele Scuttari }
352