xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (revision 3ace685105d3b50bca68328bf0c945af22d70f23)
167e0d58dSKareemErgawy-TomTom //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
267e0d58dSKareemErgawy-TomTom //
367e0d58dSKareemErgawy-TomTom // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
467e0d58dSKareemErgawy-TomTom // See https://llvm.org/LICENSE.txt for license information.
567e0d58dSKareemErgawy-TomTom // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
667e0d58dSKareemErgawy-TomTom //
767e0d58dSKareemErgawy-TomTom //===----------------------------------------------------------------------===//
867e0d58dSKareemErgawy-TomTom 
967d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h"
1067d0d7acSMichele Scuttari 
11ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1267d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
14b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
1567e0d58dSKareemErgawy-TomTom #include "mlir/Dialect/Tensor/IR/Tensor.h"
1667e0d58dSKareemErgawy-TomTom #include "mlir/IR/OpDefinition.h"
1767e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/DialectConversion.h"
1867e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1967e0d58dSKareemErgawy-TomTom #include <iterator>
2067e0d58dSKareemErgawy-TomTom #include <memory>
211fc096afSMehdi Amini #include <utility>
2267e0d58dSKareemErgawy-TomTom 
2367d0d7acSMichele Scuttari namespace mlir {
241e98d488SQuinn Dawkins #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
2567d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc"
2667d0d7acSMichele Scuttari } // namespace mlir
2767d0d7acSMichele Scuttari 
2867e0d58dSKareemErgawy-TomTom using namespace mlir;
2967e0d58dSKareemErgawy-TomTom using namespace mlir::linalg;
3067e0d58dSKareemErgawy-TomTom 
313b021fbdSKareemErgawy-TomTom static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
323b021fbdSKareemErgawy-TomTom                                            ValueRange inputs, Location loc) {
333b021fbdSKareemErgawy-TomTom   assert(inputs.size() == 1);
34550ea385SAlexander Belyaev   auto inputType = inputs[0].getType();
355550c821STres Popp   if (isa<TensorType>(inputType))
36015192c6SRiver Riddle     return nullptr;
37015192c6SRiver Riddle 
383b021fbdSKareemErgawy-TomTom   // A detensored value is converted back by creating a new tensor from its
393b021fbdSKareemErgawy-TomTom   // element(s).
40550ea385SAlexander Belyaev   return builder.create<tensor::FromElementsOp>(
41550ea385SAlexander Belyaev       loc, RankedTensorType::get({}, inputType), inputs[0]);
423b021fbdSKareemErgawy-TomTom }
433b021fbdSKareemErgawy-TomTom 
4467e0d58dSKareemErgawy-TomTom namespace {
4567e0d58dSKareemErgawy-TomTom /// Defines the criteria a TensorType must follow in order to be considered
4667e0d58dSKareemErgawy-TomTom /// "detensorable".
4767e0d58dSKareemErgawy-TomTom ///
48aa6eb2afSKareemErgawy-TomTom /// NOTE: For now, only 0-D tensors are supported.
4967e0d58dSKareemErgawy-TomTom ///
5067e0d58dSKareemErgawy-TomTom /// Returns true if tensorType can be detensored.
5167e0d58dSKareemErgawy-TomTom bool canBeDetensored(TensorType tensorType) {
5267e0d58dSKareemErgawy-TomTom   return tensorType.hasRank() && tensorType.getRank() == 0;
5367e0d58dSKareemErgawy-TomTom }
5467e0d58dSKareemErgawy-TomTom 
55aa6eb2afSKareemErgawy-TomTom bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
56aa6eb2afSKareemErgawy-TomTom   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
577c234ae5STobias Gysi   return genericOp &&
58a7cccb9cSAlexander Belyaev          llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
59a7cccb9cSAlexander Belyaev            return !typeConverter.isLegal(opOperand.get().getType());
60aa6eb2afSKareemErgawy-TomTom          });
61aa6eb2afSKareemErgawy-TomTom }
62aa6eb2afSKareemErgawy-TomTom 
6365eedcebSAlex Zinenko /// A conversion pattern for detensoring `linalg.generic` ops.
6467e0d58dSKareemErgawy-TomTom class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
6567e0d58dSKareemErgawy-TomTom public:
6667e0d58dSKareemErgawy-TomTom   using OpConversionPattern::OpConversionPattern;
6767e0d58dSKareemErgawy-TomTom   LogicalResult
68b54c724bSRiver Riddle   matchAndRewrite(GenericOp op, OpAdaptor adaptor,
6967e0d58dSKareemErgawy-TomTom                   ConversionPatternRewriter &rewriter) const override {
7067e0d58dSKareemErgawy-TomTom     Block *originalBlock = op->getBlock();
7167e0d58dSKareemErgawy-TomTom 
7265eedcebSAlex Zinenko     // Gather some information about the op before inlining its region.
73d3b3f765SJacques Pienaar     Block *opEntryBlock = &*op.getRegion().begin();
74d3b3f765SJacques Pienaar     YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
7567e0d58dSKareemErgawy-TomTom 
7667e0d58dSKareemErgawy-TomTom     // Split the op's region before the op. This way, we have a clear insertion
7767e0d58dSKareemErgawy-TomTom     // point in which the op can be inlined.
78fc64a164STres Popp     Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
79d3b3f765SJacques Pienaar     rewriter.inlineRegionBefore(op.getRegion(), newBlock);
8067e0d58dSKareemErgawy-TomTom     // Now that op's region is inlined, the operands of its YieldOp are mapped
8167e0d58dSKareemErgawy-TomTom     // to the materialized target values. Therefore, we can replace the op's
8267e0d58dSKareemErgawy-TomTom     // uses with those of its YielOp's operands.
8367e0d58dSKareemErgawy-TomTom     rewriter.replaceOp(op, yieldOp->getOperands());
8467e0d58dSKareemErgawy-TomTom 
8567e0d58dSKareemErgawy-TomTom     // No need for these intermediate blocks, merge them into 1.
86b54c724bSRiver Riddle     rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
8767e0d58dSKareemErgawy-TomTom     rewriter.mergeBlocks(newBlock, originalBlock, {});
8867e0d58dSKareemErgawy-TomTom 
8967e0d58dSKareemErgawy-TomTom     rewriter.eraseOp(&*Block::iterator(yieldOp));
9067e0d58dSKareemErgawy-TomTom 
9167e0d58dSKareemErgawy-TomTom     return success();
9267e0d58dSKareemErgawy-TomTom   }
9367e0d58dSKareemErgawy-TomTom };
9467e0d58dSKareemErgawy-TomTom 
953b021fbdSKareemErgawy-TomTom /// A conversion pattern for detensoring internal (non-entry) blocks within a
963b021fbdSKareemErgawy-TomTom /// function.
977ceffae1SRiver Riddle struct FunctionNonEntryBlockConversion
987ceffae1SRiver Riddle     : public OpInterfaceConversionPattern<FunctionOpInterface> {
99c10995a8SStella Laurenzo   FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
100aa6eb2afSKareemErgawy-TomTom                                   DenseSet<BlockArgument> blockArgsToDetensor)
1017ceffae1SRiver Riddle       : OpInterfaceConversionPattern(converter, ctx),
1021fc096afSMehdi Amini         blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
1033b021fbdSKareemErgawy-TomTom 
1043b021fbdSKareemErgawy-TomTom   LogicalResult
1057ceffae1SRiver Riddle   matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
1063b021fbdSKareemErgawy-TomTom                   ConversionPatternRewriter &rewriter) const override {
1075fcf907bSMatthias Springer     rewriter.startOpModification(op);
108ecba7c58SRiver Riddle     Region &region = op.getFunctionBody();
1093b021fbdSKareemErgawy-TomTom 
11052050f3fSMatthias Springer     for (Block &block :
11152050f3fSMatthias Springer          llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
11252050f3fSMatthias Springer       TypeConverter::SignatureConversion conversion(
11352050f3fSMatthias Springer           /*numOrigInputs=*/block.getNumArguments());
114aa6eb2afSKareemErgawy-TomTom 
115aa6eb2afSKareemErgawy-TomTom       for (BlockArgument blockArgument : block.getArguments()) {
116aa6eb2afSKareemErgawy-TomTom         int idx = blockArgument.getArgNumber();
117aa6eb2afSKareemErgawy-TomTom 
118aa6eb2afSKareemErgawy-TomTom         if (blockArgsToDetensor.count(blockArgument))
11952050f3fSMatthias Springer           conversion.addInputs(idx, {getTypeConverter()->convertType(
120aa6eb2afSKareemErgawy-TomTom                                         block.getArgumentTypes()[idx])});
121aa6eb2afSKareemErgawy-TomTom         else
12252050f3fSMatthias Springer           conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
123aa6eb2afSKareemErgawy-TomTom       }
124aa6eb2afSKareemErgawy-TomTom 
12552050f3fSMatthias Springer       rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
1263b021fbdSKareemErgawy-TomTom     }
1273b021fbdSKareemErgawy-TomTom 
1285fcf907bSMatthias Springer     rewriter.finalizeOpModification(op);
1293b021fbdSKareemErgawy-TomTom     return success();
1303b021fbdSKareemErgawy-TomTom   }
131aa6eb2afSKareemErgawy-TomTom 
132aa6eb2afSKareemErgawy-TomTom private:
133aa6eb2afSKareemErgawy-TomTom   const DenseSet<BlockArgument> blockArgsToDetensor;
1343b021fbdSKareemErgawy-TomTom };
1353b021fbdSKareemErgawy-TomTom 
13667e0d58dSKareemErgawy-TomTom class DetensorizeTypeConverter : public TypeConverter {
13767e0d58dSKareemErgawy-TomTom public:
13867e0d58dSKareemErgawy-TomTom   DetensorizeTypeConverter() {
13967e0d58dSKareemErgawy-TomTom     addConversion([](Type type) { return type; });
14067e0d58dSKareemErgawy-TomTom 
14167e0d58dSKareemErgawy-TomTom     // A TensorType that can be detensored, is converted to the underlying
14267e0d58dSKareemErgawy-TomTom     // element type.
14367e0d58dSKareemErgawy-TomTom     addConversion([](TensorType tensorType) -> Type {
14467e0d58dSKareemErgawy-TomTom       if (canBeDetensored(tensorType))
14567e0d58dSKareemErgawy-TomTom         return tensorType.getElementType();
14667e0d58dSKareemErgawy-TomTom 
14767e0d58dSKareemErgawy-TomTom       return tensorType;
14867e0d58dSKareemErgawy-TomTom     });
14967e0d58dSKareemErgawy-TomTom 
15067e0d58dSKareemErgawy-TomTom     // A tensor value is detensoried by extracting its element(s).
15167e0d58dSKareemErgawy-TomTom     addTargetMaterialization([](OpBuilder &builder, Type type,
15267e0d58dSKareemErgawy-TomTom                                 ValueRange inputs, Location loc) -> Value {
15367e0d58dSKareemErgawy-TomTom       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
15467e0d58dSKareemErgawy-TomTom     });
15567e0d58dSKareemErgawy-TomTom 
1563b021fbdSKareemErgawy-TomTom     addSourceMaterialization(sourceMaterializationCallback);
15767e0d58dSKareemErgawy-TomTom   }
15867e0d58dSKareemErgawy-TomTom };
15967e0d58dSKareemErgawy-TomTom 
16067e0d58dSKareemErgawy-TomTom /// @see LinalgDetensorize in Linalg/Passes.td for more details.
16167d0d7acSMichele Scuttari struct LinalgDetensorize
1621e98d488SQuinn Dawkins     : public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
1631e98d488SQuinn Dawkins   using impl::LinalgDetensorizePassBase<
1641e98d488SQuinn Dawkins       LinalgDetensorize>::LinalgDetensorizePassBase;
165aa6eb2afSKareemErgawy-TomTom   LinalgDetensorize() = default;
166aa6eb2afSKareemErgawy-TomTom 
167aa6eb2afSKareemErgawy-TomTom   class CostModel {
168aa6eb2afSKareemErgawy-TomTom   public:
169aa6eb2afSKareemErgawy-TomTom     virtual ~CostModel() = default;
170aa6eb2afSKareemErgawy-TomTom 
171aa6eb2afSKareemErgawy-TomTom     /// A cost model algorithm computes the following outputs:
172aa6eb2afSKareemErgawy-TomTom     ///
173aa6eb2afSKareemErgawy-TomTom     /// - opsToDetensor: the list of linalg ops that should be
174aa6eb2afSKareemErgawy-TomTom     /// detensored.
175aa6eb2afSKareemErgawy-TomTom     ///
176aa6eb2afSKareemErgawy-TomTom     /// - blockArgsToDetensor: since the operands and results of detensored
177aa6eb2afSKareemErgawy-TomTom     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
178aa6eb2afSKareemErgawy-TomTom     /// from a BB argument and a linalg op's output can be passed to successor
179aa6eb2afSKareemErgawy-TomTom     /// BBs), we need to maintain the sub-set of arguments that should be
180aa6eb2afSKareemErgawy-TomTom     /// detensored (i.e. converted by typeConverter) for each affected BB.
181aa6eb2afSKareemErgawy-TomTom     ///
182aa6eb2afSKareemErgawy-TomTom     /// Example:
183aa6eb2afSKareemErgawy-TomTom     ///
184aa6eb2afSKareemErgawy-TomTom     /// For the following snippet:
185aa6eb2afSKareemErgawy-TomTom     /// ...
186aa6eb2afSKareemErgawy-TomTom     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
18781ca5aa4SMatthias Springer     ///   %7 = tensor.empty() : tensor<i32>
188aa6eb2afSKareemErgawy-TomTom     ///   %8 = linalg.generic #attrs
189aa6eb2afSKareemErgawy-TomTom     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
190aa6eb2afSKareemErgawy-TomTom     ///     outs(%7 : tensor<i32>) {
191aa6eb2afSKareemErgawy-TomTom     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
192a54f4eaeSMogball     ///       %9 = arith.addi %arg0, %arg1 : i32
193aa6eb2afSKareemErgawy-TomTom     ///       linalg.yield %9 : i32
194aa6eb2afSKareemErgawy-TomTom     ///   } -> tensor<i32>
195aa6eb2afSKareemErgawy-TomTom     ///   %10 = "some.op"(%9)
196aa6eb2afSKareemErgawy-TomTom     ///   br ^bb2(%8 : tensor<i32>)
197aa6eb2afSKareemErgawy-TomTom     /// ...
198aa6eb2afSKareemErgawy-TomTom     ///
199aa6eb2afSKareemErgawy-TomTom     /// if the cost model decides that the linalg.generic op should be
200aa6eb2afSKareemErgawy-TomTom     /// detensored, then:
201aa6eb2afSKareemErgawy-TomTom     /// - opsToDetensor should be = {linalg.generic{add}}.
202aa6eb2afSKareemErgawy-TomTom     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
2037ceffae1SRiver Riddle     virtual void compute(FunctionOpInterface func,
204c10995a8SStella Laurenzo                          DetensorizeTypeConverter typeConverter,
205aa6eb2afSKareemErgawy-TomTom                          DenseSet<Operation *> &opsToDetensor,
206aa6eb2afSKareemErgawy-TomTom                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
207aa6eb2afSKareemErgawy-TomTom 
208aa6eb2afSKareemErgawy-TomTom     /// From the blockArgsToDetensor set computed by a CostModel
209aa6eb2afSKareemErgawy-TomTom     /// implementation, this method computes the corresponding branch op
210aa6eb2afSKareemErgawy-TomTom     /// detensoring. The result is a map from a branch op to a subset of indices
211aa6eb2afSKareemErgawy-TomTom     /// of its operands. The indices specify which of the branch op's operands
212aa6eb2afSKareemErgawy-TomTom     /// should be detensored.
213aa6eb2afSKareemErgawy-TomTom     ///
214aa6eb2afSKareemErgawy-TomTom     /// For the previous example, this method would compute: {bb2 -> {0}}.
215aa6eb2afSKareemErgawy-TomTom     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
216aa6eb2afSKareemErgawy-TomTom         const DenseSet<BlockArgument> &blockArgsToDetensor) {
217aa6eb2afSKareemErgawy-TomTom       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
218aa6eb2afSKareemErgawy-TomTom 
219aa6eb2afSKareemErgawy-TomTom       for (auto blockArgumentElem : blockArgsToDetensor) {
220aa6eb2afSKareemErgawy-TomTom         Block *block = blockArgumentElem.getOwner();
221aa6eb2afSKareemErgawy-TomTom 
222aa6eb2afSKareemErgawy-TomTom         for (PredecessorIterator pred = block->pred_begin();
223aa6eb2afSKareemErgawy-TomTom              pred != block->pred_end(); ++pred) {
224aa6eb2afSKareemErgawy-TomTom           BranchOpInterface terminator =
225aa6eb2afSKareemErgawy-TomTom               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
226aa6eb2afSKareemErgawy-TomTom           auto blockOperands =
227aa6eb2afSKareemErgawy-TomTom               terminator.getSuccessorOperands(pred.getSuccessorIndex());
228aa6eb2afSKareemErgawy-TomTom 
2290c789db5SMarkus Böck           if (blockOperands.empty() ||
2300c789db5SMarkus Böck               blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
231aa6eb2afSKareemErgawy-TomTom             continue;
232aa6eb2afSKareemErgawy-TomTom 
233aa6eb2afSKareemErgawy-TomTom           detensorableBranchOps[terminator].insert(
2340c789db5SMarkus Böck               blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
235aa6eb2afSKareemErgawy-TomTom         }
236aa6eb2afSKareemErgawy-TomTom       }
237aa6eb2afSKareemErgawy-TomTom 
238aa6eb2afSKareemErgawy-TomTom       return detensorableBranchOps;
239aa6eb2afSKareemErgawy-TomTom     }
240aa6eb2afSKareemErgawy-TomTom   };
241aa6eb2afSKareemErgawy-TomTom 
242aa6eb2afSKareemErgawy-TomTom   /// Detensorize linalg ops involved in control-flow within a function.
243aa6eb2afSKareemErgawy-TomTom   ///
244bdcf4b9bSKareemErgawy-TomTom   /// This model starts from BranchOps and CondBranchOps within a function. For
245bdcf4b9bSKareemErgawy-TomTom   /// each such branch, the model then walks the use-def chain for the branch's
246bdcf4b9bSKareemErgawy-TomTom   /// condition backwards in order to understand where the condition's value
247bdcf4b9bSKareemErgawy-TomTom   /// comes from. If the condition value is (indirectly) computed by a linalg op
248bdcf4b9bSKareemErgawy-TomTom   /// that can be detensored, the model then continues walking the use-def chain
249bdcf4b9bSKareemErgawy-TomTom   /// in order to understand where the linalg op's operands come from. This
250bdcf4b9bSKareemErgawy-TomTom   /// leads to discovering a "detensoring component". A detensoring component is
251bdcf4b9bSKareemErgawy-TomTom   /// the set of operations + block arguments that are involved in control-flow
252bdcf4b9bSKareemErgawy-TomTom   /// AND can be detensored.
253bdcf4b9bSKareemErgawy-TomTom   class ControlFlowDetectionModel : public CostModel {
254aa6eb2afSKareemErgawy-TomTom   public:
2557ceffae1SRiver Riddle     void compute(FunctionOpInterface func,
2567ceffae1SRiver Riddle                  DetensorizeTypeConverter typeConverter,
257aa6eb2afSKareemErgawy-TomTom                  DenseSet<Operation *> &opsToDetensor,
258aa6eb2afSKareemErgawy-TomTom                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
259aa6eb2afSKareemErgawy-TomTom       SmallVector<Value> workList;
260aa6eb2afSKareemErgawy-TomTom 
261ace01605SRiver Riddle       func->walk([&](cf::CondBranchOp condBr) {
26289d8035eSBenjamin Kramer         llvm::append_range(workList, condBr.getOperands());
263f984a805SKareemErgawy-TomTom       });
264f984a805SKareemErgawy-TomTom 
265ace01605SRiver Riddle       func->walk([&](cf::BranchOp br) {
26689d8035eSBenjamin Kramer         llvm::append_range(workList, br.getOperands());
267f984a805SKareemErgawy-TomTom       });
268aa6eb2afSKareemErgawy-TomTom 
269aa6eb2afSKareemErgawy-TomTom       DenseSet<Value> visitedValues;
270aa6eb2afSKareemErgawy-TomTom       DenseSet<Operation *> visitedOps;
271aa6eb2afSKareemErgawy-TomTom 
2720b05207eSKareemErgawy-TomTom       // For a (to-be-detesored) value, check if it "escapes" the block by being
2730b05207eSKareemErgawy-TomTom       // passed to terminator. If it does, then workList is updated with the
2740b05207eSKareemErgawy-TomTom       // corresponding argument to the successor block.
2750b05207eSKareemErgawy-TomTom       auto updateWorkListWithSuccessorArguments =
2760b05207eSKareemErgawy-TomTom           [&](Value value, BranchOpInterface terminator) {
2770b05207eSKareemErgawy-TomTom             if (!terminator)
2780b05207eSKareemErgawy-TomTom               return;
2790b05207eSKareemErgawy-TomTom 
2800b05207eSKareemErgawy-TomTom             for (auto operandIdx :
2810b05207eSKareemErgawy-TomTom                  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
2820b05207eSKareemErgawy-TomTom               Value operand = terminator->getOperand(operandIdx);
2830b05207eSKareemErgawy-TomTom 
2840b05207eSKareemErgawy-TomTom               if (operand == value) {
2850b05207eSKareemErgawy-TomTom                 auto succBlockArg =
2860b05207eSKareemErgawy-TomTom                     terminator.getSuccessorBlockArgument(operandIdx);
2870b05207eSKareemErgawy-TomTom 
2880b05207eSKareemErgawy-TomTom                 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
2890b05207eSKareemErgawy-TomTom                   workList.push_back(*succBlockArg);
2900b05207eSKareemErgawy-TomTom               }
2910b05207eSKareemErgawy-TomTom             }
2920b05207eSKareemErgawy-TomTom           };
2930b05207eSKareemErgawy-TomTom 
294aa6eb2afSKareemErgawy-TomTom       while (!workList.empty()) {
295aa6eb2afSKareemErgawy-TomTom         Value currentItem = workList.pop_back_val();
296aa6eb2afSKareemErgawy-TomTom 
297aa6eb2afSKareemErgawy-TomTom         if (!visitedValues.insert(currentItem).second)
298aa6eb2afSKareemErgawy-TomTom           continue;
299aa6eb2afSKareemErgawy-TomTom 
3000b05207eSKareemErgawy-TomTom         // 1   - Look forward:
3010b05207eSKareemErgawy-TomTom         // 1.1 - If currentItem escapes to one or more successors, add
3020b05207eSKareemErgawy-TomTom         // the corresponding successor arguments to workList.
3030b05207eSKareemErgawy-TomTom         updateWorkListWithSuccessorArguments(
3040b05207eSKareemErgawy-TomTom             currentItem, dyn_cast<BranchOpInterface>(
3050b05207eSKareemErgawy-TomTom                              currentItem.getParentBlock()->getTerminator()));
3060b05207eSKareemErgawy-TomTom 
3070b05207eSKareemErgawy-TomTom         // 1.2 - For each user of currentItem, add the defined values to
3080b05207eSKareemErgawy-TomTom         // workList. This way, the user ops can be inspected later if they are
3090b05207eSKareemErgawy-TomTom         // detensorable and if so, their operands will be added to workList to
3100b05207eSKareemErgawy-TomTom         // potentially discover other parts of the detensorable component.
3110b05207eSKareemErgawy-TomTom         for (auto *user : currentItem.getUsers())
31289d8035eSBenjamin Kramer           llvm::append_range(workList, user->getResults());
3130b05207eSKareemErgawy-TomTom 
3140b05207eSKareemErgawy-TomTom         // 2   - Look backward:
3150b05207eSKareemErgawy-TomTom         // 2.1 - The current item is defined by a block argument. If the owner
3160b05207eSKareemErgawy-TomTom         // block is a non-entry one, then:
3170b05207eSKareemErgawy-TomTom         //       * Add the argument to blockArgsToDetensor.
3180b05207eSKareemErgawy-TomTom         //       * Walk the use-def chain backwards to add each predecessor's
3190b05207eSKareemErgawy-TomTom         //       terminator-operands corresponding to currentItem to workList.
3205550c821STres Popp         if (dyn_cast<BlockArgument>(currentItem)) {
321aa6eb2afSKareemErgawy-TomTom           BlockArgument currentItemBlockArgument =
3225550c821STres Popp               cast<BlockArgument>(currentItem);
323aa6eb2afSKareemErgawy-TomTom           Block *ownerBlock = currentItemBlockArgument.getOwner();
324aa6eb2afSKareemErgawy-TomTom 
325aa6eb2afSKareemErgawy-TomTom           // Function arguments are not detensored/converted.
326aa6eb2afSKareemErgawy-TomTom           if (&*ownerBlock->getParent()->begin() == ownerBlock)
327aa6eb2afSKareemErgawy-TomTom             continue;
328aa6eb2afSKareemErgawy-TomTom 
329aa6eb2afSKareemErgawy-TomTom           // This inner-block argument is involved in control-flow, it should be
330aa6eb2afSKareemErgawy-TomTom           // detensored.
331aa6eb2afSKareemErgawy-TomTom           blockArgsToDetensor.insert(currentItemBlockArgument);
332aa6eb2afSKareemErgawy-TomTom 
333aa6eb2afSKareemErgawy-TomTom           for (PredecessorIterator pred = ownerBlock->pred_begin();
334aa6eb2afSKareemErgawy-TomTom                pred != ownerBlock->pred_end(); ++pred) {
335bdcf4b9bSKareemErgawy-TomTom             BranchOpInterface predTerminator =
336aa6eb2afSKareemErgawy-TomTom                 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
337aa6eb2afSKareemErgawy-TomTom 
338aa6eb2afSKareemErgawy-TomTom             // TODO: For now, we give up if any of the control-flow components
339aa6eb2afSKareemErgawy-TomTom             // in a function is not detensorable. Fix that.
340bdcf4b9bSKareemErgawy-TomTom             if (!predTerminator) {
341aa6eb2afSKareemErgawy-TomTom               opsToDetensor.clear();
342aa6eb2afSKareemErgawy-TomTom               blockArgsToDetensor.clear();
343aa6eb2afSKareemErgawy-TomTom               return;
344aa6eb2afSKareemErgawy-TomTom             }
345aa6eb2afSKareemErgawy-TomTom 
346aa6eb2afSKareemErgawy-TomTom             auto ownerBlockOperands =
347bdcf4b9bSKareemErgawy-TomTom                 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
348aa6eb2afSKareemErgawy-TomTom 
3490c789db5SMarkus Böck             if (ownerBlockOperands.empty() ||
3500c789db5SMarkus Böck                 ownerBlockOperands.isOperandProduced(
3510c789db5SMarkus Böck                     currentItemBlockArgument.getArgNumber()))
352aa6eb2afSKareemErgawy-TomTom               continue;
353aa6eb2afSKareemErgawy-TomTom 
354aa6eb2afSKareemErgawy-TomTom             // For each predecessor, add the value it passes to that argument to
355aa6eb2afSKareemErgawy-TomTom             // workList to find out how it's computed.
356aa6eb2afSKareemErgawy-TomTom             workList.push_back(
3570c789db5SMarkus Böck                 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
358aa6eb2afSKareemErgawy-TomTom           }
359aa6eb2afSKareemErgawy-TomTom 
360aa6eb2afSKareemErgawy-TomTom           continue;
361aa6eb2afSKareemErgawy-TomTom         }
362aa6eb2afSKareemErgawy-TomTom 
363aa6eb2afSKareemErgawy-TomTom         Operation *currentItemDefiningOp = currentItem.getDefiningOp();
364aa6eb2afSKareemErgawy-TomTom 
365aa6eb2afSKareemErgawy-TomTom         if (!visitedOps.insert(currentItemDefiningOp).second)
366aa6eb2afSKareemErgawy-TomTom           continue;
367aa6eb2afSKareemErgawy-TomTom 
3680b05207eSKareemErgawy-TomTom         // 2.2 - The current item is computed by a GenericOp. If the op should
3690b05207eSKareemErgawy-TomTom         // be detensored, then:
3700b05207eSKareemErgawy-TomTom         //       * Add it to opsToDetensor.
3710b05207eSKareemErgawy-TomTom         //       * Add its operands to workList to discover other parts of the
3720b05207eSKareemErgawy-TomTom         //       potentially detensorable component.
373aa6eb2afSKareemErgawy-TomTom         if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
374aa6eb2afSKareemErgawy-TomTom           // The op was encountered already, no need to inspect it again.
375aa6eb2afSKareemErgawy-TomTom           if (opsToDetensor.count(genericOp))
376aa6eb2afSKareemErgawy-TomTom             continue;
377aa6eb2afSKareemErgawy-TomTom 
378bdcf4b9bSKareemErgawy-TomTom           // The op should not be detensored, give up on it but continue with
379bdcf4b9bSKareemErgawy-TomTom           // discovering the rest of the control-flow component.
380aa6eb2afSKareemErgawy-TomTom           if (!shouldBeDetensored(genericOp, typeConverter)) {
381bdcf4b9bSKareemErgawy-TomTom             continue;
382aa6eb2afSKareemErgawy-TomTom           }
383aa6eb2afSKareemErgawy-TomTom 
384aa6eb2afSKareemErgawy-TomTom           opsToDetensor.insert(genericOp);
385d3b3f765SJacques Pienaar           llvm::append_range(workList, genericOp.getInputs());
386aa6eb2afSKareemErgawy-TomTom           continue;
387aa6eb2afSKareemErgawy-TomTom         }
388aa6eb2afSKareemErgawy-TomTom 
3890b05207eSKareemErgawy-TomTom         // 2.3 - The current item is the result of a FromElementsOp, it will be
390aa6eb2afSKareemErgawy-TomTom         // trivially detensored later as part of canonicalization patterns
391aa6eb2afSKareemErgawy-TomTom         // applied at the end of detensoring.
392aa6eb2afSKareemErgawy-TomTom         //
393aa6eb2afSKareemErgawy-TomTom         // Note: No need to check whether the result type of this op is
394aa6eb2afSKareemErgawy-TomTom         // detensorable since if it wasn't we wouldn't reach that point in the
395aa6eb2afSKareemErgawy-TomTom         // work list.
396e4c39501SRahul Joshi         if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
397aa6eb2afSKareemErgawy-TomTom           continue;
398aa6eb2afSKareemErgawy-TomTom 
3990b05207eSKareemErgawy-TomTom         // 2.4 - The current item is the result of a scalar op, add all its
4000b05207eSKareemErgawy-TomTom         // operands to the work list.
401aa6eb2afSKareemErgawy-TomTom         if (llvm::all_of(
402aa6eb2afSKareemErgawy-TomTom                 currentItemDefiningOp->getResultTypes(),
403aa6eb2afSKareemErgawy-TomTom                 [&](Type resultType) { return resultType.isIntOrFloat(); }))
40489d8035eSBenjamin Kramer           llvm::append_range(workList, currentItemDefiningOp->getOperands());
405aa6eb2afSKareemErgawy-TomTom       }
406bdcf4b9bSKareemErgawy-TomTom 
407bdcf4b9bSKareemErgawy-TomTom       // Since the cost model gives up on some ops (see the details of step 2.2
408bdcf4b9bSKareemErgawy-TomTom       // above), block arguments that correspond to the values produced by those
409bdcf4b9bSKareemErgawy-TomTom       // ops should not be detensored as well.
410bdcf4b9bSKareemErgawy-TomTom 
411bdcf4b9bSKareemErgawy-TomTom       DenseSet<BlockArgument> blockArgsToRemove;
412bdcf4b9bSKareemErgawy-TomTom 
413bdcf4b9bSKareemErgawy-TomTom       for (auto &blockArg : blockArgsToDetensor) {
414bdcf4b9bSKareemErgawy-TomTom         Block *block = blockArg.getParentBlock();
415bdcf4b9bSKareemErgawy-TomTom 
416bdcf4b9bSKareemErgawy-TomTom         // For the potentially detensorable block argument, find the
417bdcf4b9bSKareemErgawy-TomTom         // correpsonding operands in predecessor blocks.
418bdcf4b9bSKareemErgawy-TomTom         for (PredecessorIterator pred = block->pred_begin();
419bdcf4b9bSKareemErgawy-TomTom              pred != block->pred_end(); ++pred) {
420bdcf4b9bSKareemErgawy-TomTom           BranchOpInterface terminator =
421bdcf4b9bSKareemErgawy-TomTom               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
422bdcf4b9bSKareemErgawy-TomTom           auto blockOperands =
423bdcf4b9bSKareemErgawy-TomTom               terminator.getSuccessorOperands(pred.getSuccessorIndex());
424bdcf4b9bSKareemErgawy-TomTom 
4250c789db5SMarkus Böck           if (blockOperands.empty() ||
4260c789db5SMarkus Böck               blockOperands.isOperandProduced(blockArg.getArgNumber()))
427bdcf4b9bSKareemErgawy-TomTom             continue;
428bdcf4b9bSKareemErgawy-TomTom 
429bdcf4b9bSKareemErgawy-TomTom           Operation *definingOp =
4300c789db5SMarkus Böck               blockOperands[blockArg.getArgNumber()].getDefiningOp();
431bdcf4b9bSKareemErgawy-TomTom 
432bdcf4b9bSKareemErgawy-TomTom           // If the operand is defined by a GenericOp that will not be
433bdcf4b9bSKareemErgawy-TomTom           // detensored, then do not detensor the corresponding block argument.
4340c789db5SMarkus Böck           if (isa_and_nonnull<GenericOp>(definingOp) &&
435bdcf4b9bSKareemErgawy-TomTom               opsToDetensor.count(definingOp) == 0) {
436bdcf4b9bSKareemErgawy-TomTom             blockArgsToRemove.insert(blockArg);
437bdcf4b9bSKareemErgawy-TomTom             break;
438bdcf4b9bSKareemErgawy-TomTom           }
439bdcf4b9bSKareemErgawy-TomTom         }
440bdcf4b9bSKareemErgawy-TomTom       }
441bdcf4b9bSKareemErgawy-TomTom 
442bdcf4b9bSKareemErgawy-TomTom       for (auto &blockArg : blockArgsToRemove) {
443bdcf4b9bSKareemErgawy-TomTom         blockArgsToDetensor.erase(blockArg);
444bdcf4b9bSKareemErgawy-TomTom       }
445aa6eb2afSKareemErgawy-TomTom     }
446aa6eb2afSKareemErgawy-TomTom   };
447aa6eb2afSKareemErgawy-TomTom 
448aa6eb2afSKareemErgawy-TomTom   /// Detensorize everything that can detensored.
449aa6eb2afSKareemErgawy-TomTom   class AggressiveDetensoringModel : public CostModel {
450aa6eb2afSKareemErgawy-TomTom   public:
4517ceffae1SRiver Riddle     void compute(FunctionOpInterface func,
4527ceffae1SRiver Riddle                  DetensorizeTypeConverter typeConverter,
453aa6eb2afSKareemErgawy-TomTom                  DenseSet<Operation *> &opsToDetensor,
454aa6eb2afSKareemErgawy-TomTom                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
455c10995a8SStella Laurenzo       func->walk([&](GenericOp genericOp) {
456aa6eb2afSKareemErgawy-TomTom         if (shouldBeDetensored(genericOp, typeConverter))
457aa6eb2afSKareemErgawy-TomTom           opsToDetensor.insert(genericOp);
458aa6eb2afSKareemErgawy-TomTom       });
459aa6eb2afSKareemErgawy-TomTom 
460ecba7c58SRiver Riddle       for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
461aa6eb2afSKareemErgawy-TomTom         for (BlockArgument blockArgument : block.getArguments())
462aa6eb2afSKareemErgawy-TomTom           blockArgsToDetensor.insert(blockArgument);
463aa6eb2afSKareemErgawy-TomTom     }
464aa6eb2afSKareemErgawy-TomTom   };
465aa6eb2afSKareemErgawy-TomTom 
466c10995a8SStella Laurenzo   void runOnOperation() override {
467aa6eb2afSKareemErgawy-TomTom     MLIRContext *context = &getContext();
46867e0d58dSKareemErgawy-TomTom     DetensorizeTypeConverter typeConverter;
469dc4e913bSChris Lattner     RewritePatternSet patterns(context);
47067e0d58dSKareemErgawy-TomTom     ConversionTarget target(*context);
471aa6eb2afSKareemErgawy-TomTom     DenseSet<Operation *> opsToDetensor;
472aa6eb2afSKareemErgawy-TomTom     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
473aa6eb2afSKareemErgawy-TomTom     DenseSet<BlockArgument> blockArgsToDetensor;
474b6d9e30bSFelix Schneider     FunctionOpInterface funcOp = getOperation();
47567e0d58dSKareemErgawy-TomTom 
4760fac44d8SOkwan Kwon     if (funcOp.getFunctionBody().empty())
4770fac44d8SOkwan Kwon       return;
4780fac44d8SOkwan Kwon 
47965eedcebSAlex Zinenko     // Make sure the entry block of the function doesn't contain any Linalg ops.
48065eedcebSAlex Zinenko     // Otherwise, it may lead to the signature of the block being changed by the
48165eedcebSAlex Zinenko     // dialect conversion below, which would make the function op invalid
48265eedcebSAlex Zinenko     // because its type shouldn't change.
48365eedcebSAlex Zinenko     IRRewriter rewriter(funcOp->getContext());
48465eedcebSAlex Zinenko     Block *entryBlock = &funcOp.getFunctionBody().front();
48565eedcebSAlex Zinenko     Block *postEntryBlock =
48665eedcebSAlex Zinenko         rewriter.splitBlock(entryBlock, entryBlock->begin());
48765eedcebSAlex Zinenko     rewriter.setInsertionPointToStart(entryBlock);
48865eedcebSAlex Zinenko     auto branch =
48965eedcebSAlex Zinenko         rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
49065eedcebSAlex Zinenko 
491aa6eb2afSKareemErgawy-TomTom     if (aggressiveMode.getValue()) {
492aa6eb2afSKareemErgawy-TomTom       AggressiveDetensoringModel costModel;
4937ceffae1SRiver Riddle       costModel.compute(funcOp, typeConverter, opsToDetensor,
494aa6eb2afSKareemErgawy-TomTom                         blockArgsToDetensor);
495aa6eb2afSKareemErgawy-TomTom     } else {
496bdcf4b9bSKareemErgawy-TomTom       ControlFlowDetectionModel costModel;
4977ceffae1SRiver Riddle       costModel.compute(funcOp, typeConverter, opsToDetensor,
498aa6eb2afSKareemErgawy-TomTom                         blockArgsToDetensor);
499aa6eb2afSKareemErgawy-TomTom     }
500aa6eb2afSKareemErgawy-TomTom 
501aa6eb2afSKareemErgawy-TomTom     detensorableBranchOps =
502aa6eb2afSKareemErgawy-TomTom         CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
503aa6eb2afSKareemErgawy-TomTom 
504aa6eb2afSKareemErgawy-TomTom     target.addDynamicallyLegalOp<GenericOp>(
505aa6eb2afSKareemErgawy-TomTom         [&](GenericOp op) { return !opsToDetensor.count(op); });
50667e0d58dSKareemErgawy-TomTom 
507c10995a8SStella Laurenzo     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
508aa6eb2afSKareemErgawy-TomTom       // A function is legal if all of its non-entry blocks are legal. We
5090b05207eSKareemErgawy-TomTom       // don't legalize the entry block (i.e. the function's signature)
5100b05207eSKareemErgawy-TomTom       // since detensoring can't happen along external calling convention
511aa6eb2afSKareemErgawy-TomTom       // boundaries, which we conservatively approximate as all function
512aa6eb2afSKareemErgawy-TomTom       // signatures.
5137ceffae1SRiver Riddle       if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
514ecba7c58SRiver Riddle         Region &body = funcOp.getFunctionBody();
515c10995a8SStella Laurenzo         return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
5166786d7e4SMehdi Amini           return !llvm::any_of(
517c10995a8SStella Laurenzo               blockArgsToDetensor, [&](BlockArgument blockArgument) {
518aa6eb2afSKareemErgawy-TomTom                 return blockArgument.getOwner() == &block &&
519aa6eb2afSKareemErgawy-TomTom                        !typeConverter.isLegal(blockArgument.getType());
5206786d7e4SMehdi Amini               });
5213b021fbdSKareemErgawy-TomTom         });
522c10995a8SStella Laurenzo       }
5233b021fbdSKareemErgawy-TomTom 
524aa6eb2afSKareemErgawy-TomTom       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
525aa6eb2afSKareemErgawy-TomTom           isLegalForReturnOpTypeConversionPattern(op, typeConverter,
526aa6eb2afSKareemErgawy-TomTom                                                   /*returnOpAlwaysLegal*/ true))
527aa6eb2afSKareemErgawy-TomTom         return true;
528aa6eb2afSKareemErgawy-TomTom 
529aa6eb2afSKareemErgawy-TomTom       if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
530aa6eb2afSKareemErgawy-TomTom         if (!detensorableBranchOps.count(branchOp))
531aa6eb2afSKareemErgawy-TomTom           return true;
532aa6eb2afSKareemErgawy-TomTom 
533aa6eb2afSKareemErgawy-TomTom         for (auto operandIdx : detensorableBranchOps[branchOp])
534aa6eb2afSKareemErgawy-TomTom           if (!typeConverter.isLegal(
535aa6eb2afSKareemErgawy-TomTom                   branchOp->getOperand(operandIdx).getType()))
536aa6eb2afSKareemErgawy-TomTom             return false;
537aa6eb2afSKareemErgawy-TomTom 
538aa6eb2afSKareemErgawy-TomTom         return true;
539aa6eb2afSKareemErgawy-TomTom       }
540aa6eb2afSKareemErgawy-TomTom 
541aa6eb2afSKareemErgawy-TomTom       return false;
5423b021fbdSKareemErgawy-TomTom     });
5433b021fbdSKareemErgawy-TomTom 
544b4e0507cSTres Popp     patterns.add<DetensorizeGenericOp>(typeConverter, context);
545b4e0507cSTres Popp     patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
546aa6eb2afSKareemErgawy-TomTom                                                   blockArgsToDetensor);
547aa6eb2afSKareemErgawy-TomTom     // Since non-entry block arguments get detensorized, we also need to
548aa6eb2afSKareemErgawy-TomTom     // update the control flow inside the function to reflect the correct
549aa6eb2afSKareemErgawy-TomTom     // types.
550aa6eb2afSKareemErgawy-TomTom     auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
551aa6eb2afSKareemErgawy-TomTom                                           int operandIdx) -> bool {
552aa6eb2afSKareemErgawy-TomTom       return detensorableBranchOps.count(branchOp) &&
553aa6eb2afSKareemErgawy-TomTom              detensorableBranchOps[branchOp].count(operandIdx);
554aa6eb2afSKareemErgawy-TomTom     };
555aa6eb2afSKareemErgawy-TomTom 
556aa6eb2afSKareemErgawy-TomTom     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
557aa6eb2afSKareemErgawy-TomTom                                                    shouldConvertBranchOperand);
55867e0d58dSKareemErgawy-TomTom 
559c10995a8SStella Laurenzo     if (failed(
560c10995a8SStella Laurenzo             applyFullConversion(getOperation(), target, std::move(patterns))))
56167e0d58dSKareemErgawy-TomTom       signalPassFailure();
56267e0d58dSKareemErgawy-TomTom 
563dc4e913bSChris Lattner     RewritePatternSet canonPatterns(context);
564550ea385SAlexander Belyaev     tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
565*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
56667e0d58dSKareemErgawy-TomTom       signalPassFailure();
56765eedcebSAlex Zinenko 
56865eedcebSAlex Zinenko     // Get rid of the dummy entry block we created in the beginning to work
56965eedcebSAlex Zinenko     // around dialect conversion signature rewriting.
57065eedcebSAlex Zinenko     rewriter.eraseOp(branch);
57165eedcebSAlex Zinenko     rewriter.mergeBlocks(postEntryBlock, entryBlock);
57267e0d58dSKareemErgawy-TomTom   }
57367e0d58dSKareemErgawy-TomTom };
57467e0d58dSKareemErgawy-TomTom } // namespace
575