xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (revision 5a9b74d20d5f3b7f92c01d68d28778108dfb1308)
1 //===- TosaInferShapes.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Propagate shapes forward along TOSA operations to resolve dynamic shape
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 
26 namespace mlir {
27 namespace tosa {
28 #define GEN_PASS_DEF_TOSAINFERSHAPES
29 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
30 } // namespace tosa
31 } // namespace mlir
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 namespace {
37 
38 // Check whether this use case is replaceable. We define an op as
39 // being replaceable if it is used by a TosaOp, or an op with a
40 // type-inference related interface.
41 // When a non-replaceable use is encountered, the value is wrapped in a
42 // cast back to the original type after inference.
43 bool canBeRefined(Operation *user) {
44   if (!user->getDialect())
45     return false;
46   return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
47          isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
48 }
49 
50 // During type propagation, the types of values in the operator graph are
51 // updated. For the tosa.while_loop operation, types are speculatively updated
52 // within the body region to determine the output type of the while_loop. This
53 // process is performed until a fixed point is reached, then the types are
54 // rolled back.
55 //
56 // This class encapsulates the state information needed to perform the roll back
57 // process or to commit to the final changes.
58 class TypeModificationState {
59 public:
60   TypeModificationState() = default;
61 
62   ~TypeModificationState() {
63     // Ensure the recorded modifications are either committed or rolled back.
64     assert(oldTypes.empty() && "unhandled type modifications");
65   }
66 
67   // Update the state of the value and record the old type.
68   void setType(Value value, Type type) {
69     if (value.getType() != type) {
70       oldTypes.emplace_back(value, value.getType());
71       value.setType(type);
72     }
73   }
74 
75   // Roll back changes made to the types in the IR by setting all the affected
76   // values to their old types.
77   void rollBack() {
78     for (auto [value, type] : oldTypes)
79       value.setType(type);
80 
81     oldTypes.clear();
82   }
83 
84   // Commit the changes to the types in the IR.
85   // This requires inserting tensor.cast operations to mediate the newly
86   // inferred result types with users that do not support type inference.
87   void commit() {
88     // For each use whose type changed, cast the value with the new type back to
89     // the old type.
90     for (auto [value, oldType] : oldTypes) {
91       // The call to 'use->set()' in the body of the loop below invalidates the
92       // iterator used to traverse op uses, so it is important to make a copy of
93       // these first.
94       llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
95           value.getUses(),
96           [](OpOperand &use) -> OpOperand * {
97             return &use;
98           });
99 
100       // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
101       // cached and reused by all consumers.
102       tensor::CastOp castValue;
103 
104       // Traverse all uses
105       for (OpOperand *use : uses) {
106         if (canBeRefined(use->getOwner()))
107           continue;
108 
109         if (!castValue) {
110           // Set the insertion point as far back as possible, since new
111           // consumers of the 'tensor.cast' op generated in future iterations
112           // are likely to be further up in the code due to the order in which
113           // they appear in the use list.
114           OpBuilder builder{value.getContext()};
115           builder.setInsertionPointAfter(value.getDefiningOp());
116           castValue =
117               builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
118         }
119 
120         use->set(castValue);
121       }
122     }
123 
124     oldTypes.clear();
125   }
126 
127 private:
128   // A record of each value whose type was updated along with that value's
129   // previous type.
130   llvm::SmallVector<std::pair<Value, Type>> oldTypes;
131 };
132 
133 void propagateShapesInRegion(Region &region, TypeModificationState &state);
134 
// Propagate the operand types of a tosa.cond_if operation onto the block
// arguments of both of its regions, then recursively run shape propagation
// inside each region. A no-op for any operation that is not a tosa.cond_if.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  IfOp ifOp = dyn_cast<IfOp>(op);
  if (!ifOp)
    return;

  for (auto &region : op.getRegions()) {
    Block &frontBlock = region.front();
    // Operand 0 is the condition; the remaining operands map 1:1 onto the
    // region's block arguments. Bail out if the counts do not line up.
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

    // First pass: where the operand has a rank, clone its shape onto the
    // corresponding block argument (keeping the argument's element type).
    // Note the i-1 offset: argument i-1 corresponds to operand i because of
    // the leading condition operand.
    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
      auto blockArg = frontBlock.getArgument(i - 1);
      auto oldType = cast<ShapedType>(blockArg.getType());

      if (inferredTy.hasRank()) {
        Type newType = oldType.clone(inferredTy.getShape());
        state.setType(blockArg, newType);
      }
    }

    // Second pass: refine each block argument to the join of the operand's
    // type knowledge and the argument's (possibly just-updated) knowledge.
    // A failed join (incompatible knowledge) leaves the argument unchanged.
    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
          ifOp.getOperand(i + 1).getType());
      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
          frontBlock.getArgument(i).getType());
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(operandKnowledge, blockKnowledge);
      if (!joinedKnowledge)
        continue;
      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
    }

    // With the argument types settled, propagate through the region body.
    propagateShapesInRegion(region, state);
  }
}
171 
// Propagate shapes into the cond/body regions of a tosa.while_loop.
// Because the loop feeds its yielded values back into its block arguments,
// the block argument types must be compatible with every iteration; this is
// computed as a fixed point over speculative propagation passes, each of
// which is rolled back before the next. A no-op for non-while operations.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  bool hasNewTypes = true;
  while (hasNewTypes) {
    // Speculative pass: all type changes made through localState are rolled
    // back at the bottom of this loop; only argTypes carries information
    // between iterations.
    TypeModificationState localState;

    // Set types on the block args. Region 1 is the loop body (region 0 is
    // the condition).
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++) {
      localState.setType(block.getArgument(i), argTypes[i]);
    }

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion, localState);

    // Find all the tosa yield types and verify there is a single one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);

    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
    // Using the new tosa.yield operand types, infer the new subtypes.
    // Seed with the current argument knowledge, then meet with each yield
    // operand so the result stays compatible with both.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes) {
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
    }

    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // This should never happen.
    if (yieldTypeInfo.size() != argTypes.size()) {
      op.emitWarning("has a tosa.yield with the incorrect number of operands");
      return;
    }

    // Determine the new block args and see if any changed; iterate again
    // only while the fixed point has not been reached.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // Roll back all changes made during the speculative part of the
    // algorithm.
    localState.rollBack();
  }

  // We now set the block arguments according to the most recent shape
  // inference results. This gives us the block arg types for the next
  // iteration. These changes go through the caller's state and are kept.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
      state.setType(region.front().getArgument(i), argTypes[i]);
    }

    propagateShapesInRegion(region, state);
  }
}
247 
// Walk every TOSA operation in `region`, asking each op that implements
// InferShapedTypeOpInterface for its result shapes and recording any
// refinement through `state`. Control-flow ops (tosa.cond_if /
// tosa.while_loop) additionally have shapes propagated into their regions.
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();

  for (Block &block : region) {
    for (Operation &op : block) {
      // Only TOSA ops participate in this propagation.
      if (op.getDialect() != tosaDialect)
        continue;

      // Region-carrying control flow first; these helpers are no-ops for
      // every other operation.
      propagateShapesToTosaIf(op, state);
      propagateShapesToTosaWhile(op, state);

      auto shapeInterface = dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        continue;

      SmallVector<ShapedTypeComponents> returnedShapes;
      LogicalResult inferStatus = shapeInterface.inferReturnTypeComponents(
          op.getContext(), op.getLoc(), op.getOperands(),
          op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
          op.getRegions(), returnedShapes);
      if (failed(inferStatus))
        continue;

      for (auto [result, predictedShape] :
           llvm::zip(op.getResults(), returnedShapes)) {
        // Knowledge currently carried by the result's type.
        // TODO: should also query WIP type probably
        Type resultTy = result.getType();
        auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);

        // Knowledge derived from the inference interface; the element type
        // is kept from the existing result type.
        auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
        inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
        inferredKnowledge.hasRank = predictedShape.hasRank();
        if (predictedShape.hasRank()) {
          for (int64_t dim : predictedShape.getDims())
            inferredKnowledge.sizes.push_back(dim);
        }

        // Join the two; an incompatible join leaves the result untouched.
        auto newKnowledge =
            ValueKnowledge::join(currentKnowledge, inferredKnowledge);
        if (!newKnowledge)
          continue;

        // Record the refined type.
        state.setType(result, newKnowledge.getType());
      }
    }
  }
}
305 
306 /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
307 /// and all nested regions
308 void validateSameOperandsAndResultRankTrait(Region &region) {
309   int errs = 0;
310   for (auto &block : region) {
311     for (auto &op : block) {
312       if (!op.getDialect() ||
313           op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
314         continue;
315       if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
316         if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
317           errs++;
318           (void)errs;
319         }
320       }
321       WhileOp whileOp = dyn_cast<WhileOp>(op);
322       IfOp ifOp = dyn_cast<IfOp>(op);
323       if (whileOp || ifOp) {
324         // recurse into whileOp's regions
325         for (auto &next : op.getRegions()) {
326           validateSameOperandsAndResultRankTrait(next);
327         }
328       }
329     }
330   }
331 }
332 
333 /// Pass that performs shape propagation across TOSA operations. This includes
334 /// migrating to within the regions of if/while operations.
335 struct TosaInferShapes
336     : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
337 public:
338   void runOnOperation() override {
339     func::FuncOp func = getOperation();
340     TypeModificationState state;
341     propagateShapesInRegion(func.getBody(), state);
342     state.commit();
343 
344     validateSameOperandsAndResultRankTrait(func.getBody());
345   }
346 };
347 } // namespace
348 
349 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
350   return std::make_unique<TosaInferShapes>();
351 }
352