xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (revision 3ace685105d3b50bca68328bf0c945af22d70f23)
1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include <iterator>
20 #include <memory>
21 #include <utility>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
32                                            ValueRange inputs, Location loc) {
33   assert(inputs.size() == 1);
34   auto inputType = inputs[0].getType();
35   if (isa<TensorType>(inputType))
36     return nullptr;
37 
38   // A detensored value is converted back by creating a new tensor from its
39   // element(s).
40   return builder.create<tensor::FromElementsOp>(
41       loc, RankedTensorType::get({}, inputType), inputs[0]);
42 }
43 
44 namespace {
45 /// Defines the criteria a TensorType must follow in order to be considered
46 /// "detensorable".
47 ///
48 /// NOTE: For now, only 0-D tensors are supported.
49 ///
50 /// Returns true if tensorType can be detensored.
51 bool canBeDetensored(TensorType tensorType) {
52   return tensorType.hasRank() && tensorType.getRank() == 0;
53 }
54 
55 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
56   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
57   return genericOp &&
58          llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
59            return !typeConverter.isLegal(opOperand.get().getType());
60          });
61 }
62 
63 /// A conversion pattern for detensoring `linalg.generic` ops.
64 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
65 public:
66   using OpConversionPattern::OpConversionPattern;
67   LogicalResult
68   matchAndRewrite(GenericOp op, OpAdaptor adaptor,
69                   ConversionPatternRewriter &rewriter) const override {
70     Block *originalBlock = op->getBlock();
71 
72     // Gather some information about the op before inlining its region.
73     Block *opEntryBlock = &*op.getRegion().begin();
74     YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
75 
76     // Split the op's region before the op. This way, we have a clear insertion
77     // point in which the op can be inlined.
78     Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
79     rewriter.inlineRegionBefore(op.getRegion(), newBlock);
80     // Now that op's region is inlined, the operands of its YieldOp are mapped
81     // to the materialized target values. Therefore, we can replace the op's
82     // uses with those of its YielOp's operands.
83     rewriter.replaceOp(op, yieldOp->getOperands());
84 
85     // No need for these intermediate blocks, merge them into 1.
86     rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
87     rewriter.mergeBlocks(newBlock, originalBlock, {});
88 
89     rewriter.eraseOp(&*Block::iterator(yieldOp));
90 
91     return success();
92   }
93 };
94 
95 /// A conversion pattern for detensoring internal (non-entry) blocks within a
96 /// function.
97 struct FunctionNonEntryBlockConversion
98     : public OpInterfaceConversionPattern<FunctionOpInterface> {
99   FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
100                                   DenseSet<BlockArgument> blockArgsToDetensor)
101       : OpInterfaceConversionPattern(converter, ctx),
102         blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
103 
104   LogicalResult
105   matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
106                   ConversionPatternRewriter &rewriter) const override {
107     rewriter.startOpModification(op);
108     Region &region = op.getFunctionBody();
109 
110     for (Block &block :
111          llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
112       TypeConverter::SignatureConversion conversion(
113           /*numOrigInputs=*/block.getNumArguments());
114 
115       for (BlockArgument blockArgument : block.getArguments()) {
116         int idx = blockArgument.getArgNumber();
117 
118         if (blockArgsToDetensor.count(blockArgument))
119           conversion.addInputs(idx, {getTypeConverter()->convertType(
120                                         block.getArgumentTypes()[idx])});
121         else
122           conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
123       }
124 
125       rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
126     }
127 
128     rewriter.finalizeOpModification(op);
129     return success();
130   }
131 
132 private:
133   const DenseSet<BlockArgument> blockArgsToDetensor;
134 };
135 
136 class DetensorizeTypeConverter : public TypeConverter {
137 public:
138   DetensorizeTypeConverter() {
139     addConversion([](Type type) { return type; });
140 
141     // A TensorType that can be detensored, is converted to the underlying
142     // element type.
143     addConversion([](TensorType tensorType) -> Type {
144       if (canBeDetensored(tensorType))
145         return tensorType.getElementType();
146 
147       return tensorType;
148     });
149 
150     // A tensor value is detensoried by extracting its element(s).
151     addTargetMaterialization([](OpBuilder &builder, Type type,
152                                 ValueRange inputs, Location loc) -> Value {
153       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
154     });
155 
156     addSourceMaterialization(sourceMaterializationCallback);
157   }
158 };
159 
160 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
161 struct LinalgDetensorize
162     : public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
163   using impl::LinalgDetensorizePassBase<
164       LinalgDetensorize>::LinalgDetensorizePassBase;
165   LinalgDetensorize() = default;
166 
167   class CostModel {
168   public:
169     virtual ~CostModel() = default;
170 
171     /// A cost model algorithm computes the following outputs:
172     ///
173     /// - opsToDetensor: the list of linalg ops that should be
174     /// detensored.
175     ///
176     /// - blockArgsToDetensor: since the operands and results of detensored
177     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
178     /// from a BB argument and a linalg op's output can be passed to successor
179     /// BBs), we need to maintain the sub-set of arguments that should be
180     /// detensored (i.e. converted by typeConverter) for each affected BB.
181     ///
182     /// Example:
183     ///
184     /// For the following snippet:
185     /// ...
186     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
187     ///   %7 = tensor.empty() : tensor<i32>
188     ///   %8 = linalg.generic #attrs
189     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
190     ///     outs(%7 : tensor<i32>) {
191     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
192     ///       %9 = arith.addi %arg0, %arg1 : i32
193     ///       linalg.yield %9 : i32
194     ///   } -> tensor<i32>
195     ///   %10 = "some.op"(%9)
196     ///   br ^bb2(%8 : tensor<i32>)
197     /// ...
198     ///
199     /// if the cost model decides that the linalg.generic op should be
200     /// detensored, then:
201     /// - opsToDetensor should be = {linalg.generic{add}}.
202     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
203     virtual void compute(FunctionOpInterface func,
204                          DetensorizeTypeConverter typeConverter,
205                          DenseSet<Operation *> &opsToDetensor,
206                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
207 
208     /// From the blockArgsToDetensor set computed by a CostModel
209     /// implementation, this method computes the corresponding branch op
210     /// detensoring. The result is a map from a branch op to a subset of indices
211     /// of its operands. The indices specify which of the branch op's operands
212     /// should be detensored.
213     ///
214     /// For the previous example, this method would compute: {bb2 -> {0}}.
215     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
216         const DenseSet<BlockArgument> &blockArgsToDetensor) {
217       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
218 
219       for (auto blockArgumentElem : blockArgsToDetensor) {
220         Block *block = blockArgumentElem.getOwner();
221 
222         for (PredecessorIterator pred = block->pred_begin();
223              pred != block->pred_end(); ++pred) {
224           BranchOpInterface terminator =
225               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
226           auto blockOperands =
227               terminator.getSuccessorOperands(pred.getSuccessorIndex());
228 
229           if (blockOperands.empty() ||
230               blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
231             continue;
232 
233           detensorableBranchOps[terminator].insert(
234               blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
235         }
236       }
237 
238       return detensorableBranchOps;
239     }
240   };
241 
242   /// Detensorize linalg ops involved in control-flow within a function.
243   ///
244   /// This model starts from BranchOps and CondBranchOps within a function. For
245   /// each such branch, the model then walks the use-def chain for the branch's
246   /// condition backwards in order to understand where the condition's value
247   /// comes from. If the condition value is (indirectly) computed by a linalg op
248   /// that can be detensored, the model then continues walking the use-def chain
249   /// in order to understand where the linalg op's operands come from. This
250   /// leads to discovering a "detensoring component". A detensoring component is
251   /// the set of operations + block arguments that are involved in control-flow
252   /// AND can be detensored.
253   class ControlFlowDetectionModel : public CostModel {
254   public:
255     void compute(FunctionOpInterface func,
256                  DetensorizeTypeConverter typeConverter,
257                  DenseSet<Operation *> &opsToDetensor,
258                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
259       SmallVector<Value> workList;
260 
261       func->walk([&](cf::CondBranchOp condBr) {
262         llvm::append_range(workList, condBr.getOperands());
263       });
264 
265       func->walk([&](cf::BranchOp br) {
266         llvm::append_range(workList, br.getOperands());
267       });
268 
269       DenseSet<Value> visitedValues;
270       DenseSet<Operation *> visitedOps;
271 
272       // For a (to-be-detesored) value, check if it "escapes" the block by being
273       // passed to terminator. If it does, then workList is updated with the
274       // corresponding argument to the successor block.
275       auto updateWorkListWithSuccessorArguments =
276           [&](Value value, BranchOpInterface terminator) {
277             if (!terminator)
278               return;
279 
280             for (auto operandIdx :
281                  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
282               Value operand = terminator->getOperand(operandIdx);
283 
284               if (operand == value) {
285                 auto succBlockArg =
286                     terminator.getSuccessorBlockArgument(operandIdx);
287 
288                 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
289                   workList.push_back(*succBlockArg);
290               }
291             }
292           };
293 
294       while (!workList.empty()) {
295         Value currentItem = workList.pop_back_val();
296 
297         if (!visitedValues.insert(currentItem).second)
298           continue;
299 
300         // 1   - Look forward:
301         // 1.1 - If currentItem escapes to one or more successors, add
302         // the corresponding successor arguments to workList.
303         updateWorkListWithSuccessorArguments(
304             currentItem, dyn_cast<BranchOpInterface>(
305                              currentItem.getParentBlock()->getTerminator()));
306 
307         // 1.2 - For each user of currentItem, add the defined values to
308         // workList. This way, the user ops can be inspected later if they are
309         // detensorable and if so, their operands will be added to workList to
310         // potentially discover other parts of the detensorable component.
311         for (auto *user : currentItem.getUsers())
312           llvm::append_range(workList, user->getResults());
313 
314         // 2   - Look backward:
315         // 2.1 - The current item is defined by a block argument. If the owner
316         // block is a non-entry one, then:
317         //       * Add the argument to blockArgsToDetensor.
318         //       * Walk the use-def chain backwards to add each predecessor's
319         //       terminator-operands corresponding to currentItem to workList.
320         if (dyn_cast<BlockArgument>(currentItem)) {
321           BlockArgument currentItemBlockArgument =
322               cast<BlockArgument>(currentItem);
323           Block *ownerBlock = currentItemBlockArgument.getOwner();
324 
325           // Function arguments are not detensored/converted.
326           if (&*ownerBlock->getParent()->begin() == ownerBlock)
327             continue;
328 
329           // This inner-block argument is involved in control-flow, it should be
330           // detensored.
331           blockArgsToDetensor.insert(currentItemBlockArgument);
332 
333           for (PredecessorIterator pred = ownerBlock->pred_begin();
334                pred != ownerBlock->pred_end(); ++pred) {
335             BranchOpInterface predTerminator =
336                 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
337 
338             // TODO: For now, we give up if any of the control-flow components
339             // in a function is not detensorable. Fix that.
340             if (!predTerminator) {
341               opsToDetensor.clear();
342               blockArgsToDetensor.clear();
343               return;
344             }
345 
346             auto ownerBlockOperands =
347                 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
348 
349             if (ownerBlockOperands.empty() ||
350                 ownerBlockOperands.isOperandProduced(
351                     currentItemBlockArgument.getArgNumber()))
352               continue;
353 
354             // For each predecessor, add the value it passes to that argument to
355             // workList to find out how it's computed.
356             workList.push_back(
357                 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
358           }
359 
360           continue;
361         }
362 
363         Operation *currentItemDefiningOp = currentItem.getDefiningOp();
364 
365         if (!visitedOps.insert(currentItemDefiningOp).second)
366           continue;
367 
368         // 2.2 - The current item is computed by a GenericOp. If the op should
369         // be detensored, then:
370         //       * Add it to opsToDetensor.
371         //       * Add its operands to workList to discover other parts of the
372         //       potentially detensorable component.
373         if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
374           // The op was encountered already, no need to inspect it again.
375           if (opsToDetensor.count(genericOp))
376             continue;
377 
378           // The op should not be detensored, give up on it but continue with
379           // discovering the rest of the control-flow component.
380           if (!shouldBeDetensored(genericOp, typeConverter)) {
381             continue;
382           }
383 
384           opsToDetensor.insert(genericOp);
385           llvm::append_range(workList, genericOp.getInputs());
386           continue;
387         }
388 
389         // 2.3 - The current item is the result of a FromElementsOp, it will be
390         // trivially detensored later as part of canonicalization patterns
391         // applied at the end of detensoring.
392         //
393         // Note: No need to check whether the result type of this op is
394         // detensorable since if it wasn't we wouldn't reach that point in the
395         // work list.
396         if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
397           continue;
398 
399         // 2.4 - The current item is the result of a scalar op, add all its
400         // operands to the work list.
401         if (llvm::all_of(
402                 currentItemDefiningOp->getResultTypes(),
403                 [&](Type resultType) { return resultType.isIntOrFloat(); }))
404           llvm::append_range(workList, currentItemDefiningOp->getOperands());
405       }
406 
407       // Since the cost model gives up on some ops (see the details of step 2.2
408       // above), block arguments that correspond to the values produced by those
409       // ops should not be detensored as well.
410 
411       DenseSet<BlockArgument> blockArgsToRemove;
412 
413       for (auto &blockArg : blockArgsToDetensor) {
414         Block *block = blockArg.getParentBlock();
415 
416         // For the potentially detensorable block argument, find the
417         // correpsonding operands in predecessor blocks.
418         for (PredecessorIterator pred = block->pred_begin();
419              pred != block->pred_end(); ++pred) {
420           BranchOpInterface terminator =
421               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
422           auto blockOperands =
423               terminator.getSuccessorOperands(pred.getSuccessorIndex());
424 
425           if (blockOperands.empty() ||
426               blockOperands.isOperandProduced(blockArg.getArgNumber()))
427             continue;
428 
429           Operation *definingOp =
430               blockOperands[blockArg.getArgNumber()].getDefiningOp();
431 
432           // If the operand is defined by a GenericOp that will not be
433           // detensored, then do not detensor the corresponding block argument.
434           if (isa_and_nonnull<GenericOp>(definingOp) &&
435               opsToDetensor.count(definingOp) == 0) {
436             blockArgsToRemove.insert(blockArg);
437             break;
438           }
439         }
440       }
441 
442       for (auto &blockArg : blockArgsToRemove) {
443         blockArgsToDetensor.erase(blockArg);
444       }
445     }
446   };
447 
448   /// Detensorize everything that can detensored.
449   class AggressiveDetensoringModel : public CostModel {
450   public:
451     void compute(FunctionOpInterface func,
452                  DetensorizeTypeConverter typeConverter,
453                  DenseSet<Operation *> &opsToDetensor,
454                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
455       func->walk([&](GenericOp genericOp) {
456         if (shouldBeDetensored(genericOp, typeConverter))
457           opsToDetensor.insert(genericOp);
458       });
459 
460       for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
461         for (BlockArgument blockArgument : block.getArguments())
462           blockArgsToDetensor.insert(blockArgument);
463     }
464   };
465 
466   void runOnOperation() override {
467     MLIRContext *context = &getContext();
468     DetensorizeTypeConverter typeConverter;
469     RewritePatternSet patterns(context);
470     ConversionTarget target(*context);
471     DenseSet<Operation *> opsToDetensor;
472     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
473     DenseSet<BlockArgument> blockArgsToDetensor;
474     FunctionOpInterface funcOp = getOperation();
475 
476     if (funcOp.getFunctionBody().empty())
477       return;
478 
479     // Make sure the entry block of the function doesn't contain any Linalg ops.
480     // Otherwise, it may lead to the signature of the block being changed by the
481     // dialect conversion below, which would make the function op invalid
482     // because its type shouldn't change.
483     IRRewriter rewriter(funcOp->getContext());
484     Block *entryBlock = &funcOp.getFunctionBody().front();
485     Block *postEntryBlock =
486         rewriter.splitBlock(entryBlock, entryBlock->begin());
487     rewriter.setInsertionPointToStart(entryBlock);
488     auto branch =
489         rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
490 
491     if (aggressiveMode.getValue()) {
492       AggressiveDetensoringModel costModel;
493       costModel.compute(funcOp, typeConverter, opsToDetensor,
494                         blockArgsToDetensor);
495     } else {
496       ControlFlowDetectionModel costModel;
497       costModel.compute(funcOp, typeConverter, opsToDetensor,
498                         blockArgsToDetensor);
499     }
500 
501     detensorableBranchOps =
502         CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
503 
504     target.addDynamicallyLegalOp<GenericOp>(
505         [&](GenericOp op) { return !opsToDetensor.count(op); });
506 
507     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
508       // A function is legal if all of its non-entry blocks are legal. We
509       // don't legalize the entry block (i.e. the function's signature)
510       // since detensoring can't happen along external calling convention
511       // boundaries, which we conservatively approximate as all function
512       // signatures.
513       if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
514         Region &body = funcOp.getFunctionBody();
515         return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
516           return !llvm::any_of(
517               blockArgsToDetensor, [&](BlockArgument blockArgument) {
518                 return blockArgument.getOwner() == &block &&
519                        !typeConverter.isLegal(blockArgument.getType());
520               });
521         });
522       }
523 
524       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
525           isLegalForReturnOpTypeConversionPattern(op, typeConverter,
526                                                   /*returnOpAlwaysLegal*/ true))
527         return true;
528 
529       if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
530         if (!detensorableBranchOps.count(branchOp))
531           return true;
532 
533         for (auto operandIdx : detensorableBranchOps[branchOp])
534           if (!typeConverter.isLegal(
535                   branchOp->getOperand(operandIdx).getType()))
536             return false;
537 
538         return true;
539       }
540 
541       return false;
542     });
543 
544     patterns.add<DetensorizeGenericOp>(typeConverter, context);
545     patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
546                                                   blockArgsToDetensor);
547     // Since non-entry block arguments get detensorized, we also need to
548     // update the control flow inside the function to reflect the correct
549     // types.
550     auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
551                                           int operandIdx) -> bool {
552       return detensorableBranchOps.count(branchOp) &&
553              detensorableBranchOps[branchOp].count(operandIdx);
554     };
555 
556     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
557                                                    shouldConvertBranchOperand);
558 
559     if (failed(
560             applyFullConversion(getOperation(), target, std::move(patterns))))
561       signalPassFailure();
562 
563     RewritePatternSet canonPatterns(context);
564     tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
565     if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
566       signalPassFailure();
567 
568     // Get rid of the dummy entry block we created in the beginning to work
569     // around dialect conversion signature rewriting.
570     rewriter.eraseOp(branch);
571     rewriter.mergeBlocks(postEntryBlock, entryBlock);
572   }
573 };
574 } // namespace
575