167e0d58dSKareemErgawy-TomTom //===- Detensorize.cpp - Linalg transformations as patterns ----------===// 267e0d58dSKareemErgawy-TomTom // 367e0d58dSKareemErgawy-TomTom // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 467e0d58dSKareemErgawy-TomTom // See https://llvm.org/LICENSE.txt for license information. 567e0d58dSKareemErgawy-TomTom // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 667e0d58dSKareemErgawy-TomTom // 767e0d58dSKareemErgawy-TomTom //===----------------------------------------------------------------------===// 867e0d58dSKareemErgawy-TomTom 967d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h" 1067d0d7acSMichele Scuttari 11ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 1267d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 1567e0d58dSKareemErgawy-TomTom #include "mlir/Dialect/Tensor/IR/Tensor.h" 1667e0d58dSKareemErgawy-TomTom #include "mlir/IR/OpDefinition.h" 1767e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/DialectConversion.h" 1867e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 1967e0d58dSKareemErgawy-TomTom #include <iterator> 2067e0d58dSKareemErgawy-TomTom #include <memory> 211fc096afSMehdi Amini #include <utility> 2267e0d58dSKareemErgawy-TomTom 2367d0d7acSMichele Scuttari namespace mlir { 241e98d488SQuinn Dawkins #define GEN_PASS_DEF_LINALGDETENSORIZEPASS 2567d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc" 2667d0d7acSMichele Scuttari } // namespace mlir 2767d0d7acSMichele Scuttari 2867e0d58dSKareemErgawy-TomTom using namespace mlir; 2967e0d58dSKareemErgawy-TomTom using namespace mlir::linalg; 3067e0d58dSKareemErgawy-TomTom 313b021fbdSKareemErgawy-TomTom static Value sourceMaterializationCallback(OpBuilder &builder, Type type, 323b021fbdSKareemErgawy-TomTom ValueRange inputs, Location loc) { 333b021fbdSKareemErgawy-TomTom assert(inputs.size() == 1); 34550ea385SAlexander Belyaev auto inputType = inputs[0].getType(); 355550c821STres Popp if (isa<TensorType>(inputType)) 36015192c6SRiver Riddle return nullptr; 37015192c6SRiver Riddle 383b021fbdSKareemErgawy-TomTom // A detensored value is converted back by creating a new tensor from its 393b021fbdSKareemErgawy-TomTom // element(s). 40550ea385SAlexander Belyaev return builder.create<tensor::FromElementsOp>( 41550ea385SAlexander Belyaev loc, RankedTensorType::get({}, inputType), inputs[0]); 423b021fbdSKareemErgawy-TomTom } 433b021fbdSKareemErgawy-TomTom 4467e0d58dSKareemErgawy-TomTom namespace { 4567e0d58dSKareemErgawy-TomTom /// Defines the criteria a TensorType must follow in order to be considered 4667e0d58dSKareemErgawy-TomTom /// "detensorable". 4767e0d58dSKareemErgawy-TomTom /// 48aa6eb2afSKareemErgawy-TomTom /// NOTE: For now, only 0-D tensors are supported. 4967e0d58dSKareemErgawy-TomTom /// 5067e0d58dSKareemErgawy-TomTom /// Returns true if tensorType can be detensored. 5167e0d58dSKareemErgawy-TomTom bool canBeDetensored(TensorType tensorType) { 5267e0d58dSKareemErgawy-TomTom return tensorType.hasRank() && tensorType.getRank() == 0; 5367e0d58dSKareemErgawy-TomTom } 5467e0d58dSKareemErgawy-TomTom 55aa6eb2afSKareemErgawy-TomTom bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 56aa6eb2afSKareemErgawy-TomTom GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 577c234ae5STobias Gysi return genericOp && 58a7cccb9cSAlexander Belyaev llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) { 59a7cccb9cSAlexander Belyaev return !typeConverter.isLegal(opOperand.get().getType()); 60aa6eb2afSKareemErgawy-TomTom }); 61aa6eb2afSKareemErgawy-TomTom } 62aa6eb2afSKareemErgawy-TomTom 6365eedcebSAlex Zinenko /// A conversion pattern for detensoring `linalg.generic` ops. 6467e0d58dSKareemErgawy-TomTom class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 6567e0d58dSKareemErgawy-TomTom public: 6667e0d58dSKareemErgawy-TomTom using OpConversionPattern::OpConversionPattern; 6767e0d58dSKareemErgawy-TomTom LogicalResult 68b54c724bSRiver Riddle matchAndRewrite(GenericOp op, OpAdaptor adaptor, 6967e0d58dSKareemErgawy-TomTom ConversionPatternRewriter &rewriter) const override { 7067e0d58dSKareemErgawy-TomTom Block *originalBlock = op->getBlock(); 7167e0d58dSKareemErgawy-TomTom 7265eedcebSAlex Zinenko // Gather some information about the op before inlining its region. 73d3b3f765SJacques Pienaar Block *opEntryBlock = &*op.getRegion().begin(); 74d3b3f765SJacques Pienaar YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator()); 7567e0d58dSKareemErgawy-TomTom 7667e0d58dSKareemErgawy-TomTom // Split the op's region before the op. This way, we have a clear insertion 7767e0d58dSKareemErgawy-TomTom // point in which the op can be inlined. 78fc64a164STres Popp Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); 79d3b3f765SJacques Pienaar rewriter.inlineRegionBefore(op.getRegion(), newBlock); 8067e0d58dSKareemErgawy-TomTom // Now that op's region is inlined, the operands of its YieldOp are mapped 8167e0d58dSKareemErgawy-TomTom // to the materialized target values. Therefore, we can replace the op's 8267e0d58dSKareemErgawy-TomTom // uses with those of its YielOp's operands. 8367e0d58dSKareemErgawy-TomTom rewriter.replaceOp(op, yieldOp->getOperands()); 8467e0d58dSKareemErgawy-TomTom 8567e0d58dSKareemErgawy-TomTom // No need for these intermediate blocks, merge them into 1. 86b54c724bSRiver Riddle rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); 8767e0d58dSKareemErgawy-TomTom rewriter.mergeBlocks(newBlock, originalBlock, {}); 8867e0d58dSKareemErgawy-TomTom 8967e0d58dSKareemErgawy-TomTom rewriter.eraseOp(&*Block::iterator(yieldOp)); 9067e0d58dSKareemErgawy-TomTom 9167e0d58dSKareemErgawy-TomTom return success(); 9267e0d58dSKareemErgawy-TomTom } 9367e0d58dSKareemErgawy-TomTom }; 9467e0d58dSKareemErgawy-TomTom 953b021fbdSKareemErgawy-TomTom /// A conversion pattern for detensoring internal (non-entry) blocks within a 963b021fbdSKareemErgawy-TomTom /// function. 977ceffae1SRiver Riddle struct FunctionNonEntryBlockConversion 987ceffae1SRiver Riddle : public OpInterfaceConversionPattern<FunctionOpInterface> { 99c10995a8SStella Laurenzo FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, 100aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> blockArgsToDetensor) 1017ceffae1SRiver Riddle : OpInterfaceConversionPattern(converter, ctx), 1021fc096afSMehdi Amini blockArgsToDetensor(std::move(blockArgsToDetensor)) {} 1033b021fbdSKareemErgawy-TomTom 1043b021fbdSKareemErgawy-TomTom LogicalResult 1057ceffae1SRiver Riddle matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands, 1063b021fbdSKareemErgawy-TomTom ConversionPatternRewriter &rewriter) const override { 1075fcf907bSMatthias Springer rewriter.startOpModification(op); 108ecba7c58SRiver Riddle Region ®ion = op.getFunctionBody(); 1093b021fbdSKareemErgawy-TomTom 11052050f3fSMatthias Springer for (Block &block : 11152050f3fSMatthias Springer llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { 11252050f3fSMatthias Springer TypeConverter::SignatureConversion conversion( 11352050f3fSMatthias Springer /*numOrigInputs=*/block.getNumArguments()); 114aa6eb2afSKareemErgawy-TomTom 115aa6eb2afSKareemErgawy-TomTom for (BlockArgument blockArgument : block.getArguments()) { 116aa6eb2afSKareemErgawy-TomTom int idx = blockArgument.getArgNumber(); 117aa6eb2afSKareemErgawy-TomTom 118aa6eb2afSKareemErgawy-TomTom if (blockArgsToDetensor.count(blockArgument)) 11952050f3fSMatthias Springer conversion.addInputs(idx, {getTypeConverter()->convertType( 120aa6eb2afSKareemErgawy-TomTom block.getArgumentTypes()[idx])}); 121aa6eb2afSKareemErgawy-TomTom else 12252050f3fSMatthias Springer conversion.addInputs(idx, {block.getArgumentTypes()[idx]}); 123aa6eb2afSKareemErgawy-TomTom } 124aa6eb2afSKareemErgawy-TomTom 12552050f3fSMatthias Springer rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); 1263b021fbdSKareemErgawy-TomTom } 1273b021fbdSKareemErgawy-TomTom 1285fcf907bSMatthias Springer rewriter.finalizeOpModification(op); 1293b021fbdSKareemErgawy-TomTom return success(); 1303b021fbdSKareemErgawy-TomTom } 131aa6eb2afSKareemErgawy-TomTom 132aa6eb2afSKareemErgawy-TomTom private: 133aa6eb2afSKareemErgawy-TomTom const DenseSet<BlockArgument> blockArgsToDetensor; 1343b021fbdSKareemErgawy-TomTom }; 1353b021fbdSKareemErgawy-TomTom 13667e0d58dSKareemErgawy-TomTom class DetensorizeTypeConverter : public TypeConverter { 13767e0d58dSKareemErgawy-TomTom public: 13867e0d58dSKareemErgawy-TomTom DetensorizeTypeConverter() { 13967e0d58dSKareemErgawy-TomTom addConversion([](Type type) { return type; }); 14067e0d58dSKareemErgawy-TomTom 14167e0d58dSKareemErgawy-TomTom // A TensorType that can be detensored, is converted to the underlying 14267e0d58dSKareemErgawy-TomTom // element type. 14367e0d58dSKareemErgawy-TomTom addConversion([](TensorType tensorType) -> Type { 14467e0d58dSKareemErgawy-TomTom if (canBeDetensored(tensorType)) 14567e0d58dSKareemErgawy-TomTom return tensorType.getElementType(); 14667e0d58dSKareemErgawy-TomTom 14767e0d58dSKareemErgawy-TomTom return tensorType; 14867e0d58dSKareemErgawy-TomTom }); 14967e0d58dSKareemErgawy-TomTom 15067e0d58dSKareemErgawy-TomTom // A tensor value is detensoried by extracting its element(s). 15167e0d58dSKareemErgawy-TomTom addTargetMaterialization([](OpBuilder &builder, Type type, 15267e0d58dSKareemErgawy-TomTom ValueRange inputs, Location loc) -> Value { 15367e0d58dSKareemErgawy-TomTom return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 15467e0d58dSKareemErgawy-TomTom }); 15567e0d58dSKareemErgawy-TomTom 1563b021fbdSKareemErgawy-TomTom addSourceMaterialization(sourceMaterializationCallback); 15767e0d58dSKareemErgawy-TomTom } 15867e0d58dSKareemErgawy-TomTom }; 15967e0d58dSKareemErgawy-TomTom 16067e0d58dSKareemErgawy-TomTom /// @see LinalgDetensorize in Linalg/Passes.td for more details. 16167d0d7acSMichele Scuttari struct LinalgDetensorize 1621e98d488SQuinn Dawkins : public impl::LinalgDetensorizePassBase<LinalgDetensorize> { 1631e98d488SQuinn Dawkins using impl::LinalgDetensorizePassBase< 1641e98d488SQuinn Dawkins LinalgDetensorize>::LinalgDetensorizePassBase; 165aa6eb2afSKareemErgawy-TomTom LinalgDetensorize() = default; 166aa6eb2afSKareemErgawy-TomTom 167aa6eb2afSKareemErgawy-TomTom class CostModel { 168aa6eb2afSKareemErgawy-TomTom public: 169aa6eb2afSKareemErgawy-TomTom virtual ~CostModel() = default; 170aa6eb2afSKareemErgawy-TomTom 171aa6eb2afSKareemErgawy-TomTom /// A cost model algorithm computes the following outputs: 172aa6eb2afSKareemErgawy-TomTom /// 173aa6eb2afSKareemErgawy-TomTom /// - opsToDetensor: the list of linalg ops that should be 174aa6eb2afSKareemErgawy-TomTom /// detensored. 175aa6eb2afSKareemErgawy-TomTom /// 176aa6eb2afSKareemErgawy-TomTom /// - blockArgsToDetensor: since the operands and results of detensored 177aa6eb2afSKareemErgawy-TomTom /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 178aa6eb2afSKareemErgawy-TomTom /// from a BB argument and a linalg op's output can be passed to successor 179aa6eb2afSKareemErgawy-TomTom /// BBs), we need to maintain the sub-set of arguments that should be 180aa6eb2afSKareemErgawy-TomTom /// detensored (i.e. converted by typeConverter) for each affected BB. 181aa6eb2afSKareemErgawy-TomTom /// 182aa6eb2afSKareemErgawy-TomTom /// Example: 183aa6eb2afSKareemErgawy-TomTom /// 184aa6eb2afSKareemErgawy-TomTom /// For the following snippet: 185aa6eb2afSKareemErgawy-TomTom /// ... 186aa6eb2afSKareemErgawy-TomTom /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 18781ca5aa4SMatthias Springer /// %7 = tensor.empty() : tensor<i32> 188aa6eb2afSKareemErgawy-TomTom /// %8 = linalg.generic #attrs 189aa6eb2afSKareemErgawy-TomTom /// ins(%6, %6 : tensor<i32>, tensor<i32>) 190aa6eb2afSKareemErgawy-TomTom /// outs(%7 : tensor<i32>) { 191aa6eb2afSKareemErgawy-TomTom /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 192a54f4eaeSMogball /// %9 = arith.addi %arg0, %arg1 : i32 193aa6eb2afSKareemErgawy-TomTom /// linalg.yield %9 : i32 194aa6eb2afSKareemErgawy-TomTom /// } -> tensor<i32> 195aa6eb2afSKareemErgawy-TomTom /// %10 = "some.op"(%9) 196aa6eb2afSKareemErgawy-TomTom /// br ^bb2(%8 : tensor<i32>) 197aa6eb2afSKareemErgawy-TomTom /// ... 198aa6eb2afSKareemErgawy-TomTom /// 199aa6eb2afSKareemErgawy-TomTom /// if the cost model decides that the linalg.generic op should be 200aa6eb2afSKareemErgawy-TomTom /// detensored, then: 201aa6eb2afSKareemErgawy-TomTom /// - opsToDetensor should be = {linalg.generic{add}}. 202aa6eb2afSKareemErgawy-TomTom /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 2037ceffae1SRiver Riddle virtual void compute(FunctionOpInterface func, 204c10995a8SStella Laurenzo DetensorizeTypeConverter typeConverter, 205aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> &opsToDetensor, 206aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 207aa6eb2afSKareemErgawy-TomTom 208aa6eb2afSKareemErgawy-TomTom /// From the blockArgsToDetensor set computed by a CostModel 209aa6eb2afSKareemErgawy-TomTom /// implementation, this method computes the corresponding branch op 210aa6eb2afSKareemErgawy-TomTom /// detensoring. The result is a map from a branch op to a subset of indices 211aa6eb2afSKareemErgawy-TomTom /// of its operands. The indices specify which of the branch op's operands 212aa6eb2afSKareemErgawy-TomTom /// should be detensored. 213aa6eb2afSKareemErgawy-TomTom /// 214aa6eb2afSKareemErgawy-TomTom /// For the previous example, this method would compute: {bb2 -> {0}}. 215aa6eb2afSKareemErgawy-TomTom static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 216aa6eb2afSKareemErgawy-TomTom const DenseSet<BlockArgument> &blockArgsToDetensor) { 217aa6eb2afSKareemErgawy-TomTom DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 218aa6eb2afSKareemErgawy-TomTom 219aa6eb2afSKareemErgawy-TomTom for (auto blockArgumentElem : blockArgsToDetensor) { 220aa6eb2afSKareemErgawy-TomTom Block *block = blockArgumentElem.getOwner(); 221aa6eb2afSKareemErgawy-TomTom 222aa6eb2afSKareemErgawy-TomTom for (PredecessorIterator pred = block->pred_begin(); 223aa6eb2afSKareemErgawy-TomTom pred != block->pred_end(); ++pred) { 224aa6eb2afSKareemErgawy-TomTom BranchOpInterface terminator = 225aa6eb2afSKareemErgawy-TomTom dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 226aa6eb2afSKareemErgawy-TomTom auto blockOperands = 227aa6eb2afSKareemErgawy-TomTom terminator.getSuccessorOperands(pred.getSuccessorIndex()); 228aa6eb2afSKareemErgawy-TomTom 2290c789db5SMarkus Böck if (blockOperands.empty() || 2300c789db5SMarkus Böck blockOperands.isOperandProduced(blockArgumentElem.getArgNumber())) 231aa6eb2afSKareemErgawy-TomTom continue; 232aa6eb2afSKareemErgawy-TomTom 233aa6eb2afSKareemErgawy-TomTom detensorableBranchOps[terminator].insert( 2340c789db5SMarkus Böck blockOperands.getOperandIndex(blockArgumentElem.getArgNumber())); 235aa6eb2afSKareemErgawy-TomTom } 236aa6eb2afSKareemErgawy-TomTom } 237aa6eb2afSKareemErgawy-TomTom 238aa6eb2afSKareemErgawy-TomTom return detensorableBranchOps; 239aa6eb2afSKareemErgawy-TomTom } 240aa6eb2afSKareemErgawy-TomTom }; 241aa6eb2afSKareemErgawy-TomTom 242aa6eb2afSKareemErgawy-TomTom /// Detensorize linalg ops involved in control-flow within a function. 243aa6eb2afSKareemErgawy-TomTom /// 244bdcf4b9bSKareemErgawy-TomTom /// This model starts from BranchOps and CondBranchOps within a function. For 245bdcf4b9bSKareemErgawy-TomTom /// each such branch, the model then walks the use-def chain for the branch's 246bdcf4b9bSKareemErgawy-TomTom /// condition backwards in order to understand where the condition's value 247bdcf4b9bSKareemErgawy-TomTom /// comes from. If the condition value is (indirectly) computed by a linalg op 248bdcf4b9bSKareemErgawy-TomTom /// that can be detensored, the model then continues walking the use-def chain 249bdcf4b9bSKareemErgawy-TomTom /// in order to understand where the linalg op's operands come from. This 250bdcf4b9bSKareemErgawy-TomTom /// leads to discovering a "detensoring component". A detensoring component is 251bdcf4b9bSKareemErgawy-TomTom /// the set of operations + block arguments that are involved in control-flow 252bdcf4b9bSKareemErgawy-TomTom /// AND can be detensored. 253bdcf4b9bSKareemErgawy-TomTom class ControlFlowDetectionModel : public CostModel { 254aa6eb2afSKareemErgawy-TomTom public: 2557ceffae1SRiver Riddle void compute(FunctionOpInterface func, 2567ceffae1SRiver Riddle DetensorizeTypeConverter typeConverter, 257aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> &opsToDetensor, 258aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> &blockArgsToDetensor) override { 259aa6eb2afSKareemErgawy-TomTom SmallVector<Value> workList; 260aa6eb2afSKareemErgawy-TomTom 261ace01605SRiver Riddle func->walk([&](cf::CondBranchOp condBr) { 26289d8035eSBenjamin Kramer llvm::append_range(workList, condBr.getOperands()); 263f984a805SKareemErgawy-TomTom }); 264f984a805SKareemErgawy-TomTom 265ace01605SRiver Riddle func->walk([&](cf::BranchOp br) { 26689d8035eSBenjamin Kramer llvm::append_range(workList, br.getOperands()); 267f984a805SKareemErgawy-TomTom }); 268aa6eb2afSKareemErgawy-TomTom 269aa6eb2afSKareemErgawy-TomTom DenseSet<Value> visitedValues; 270aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> visitedOps; 271aa6eb2afSKareemErgawy-TomTom 2720b05207eSKareemErgawy-TomTom // For a (to-be-detesored) value, check if it "escapes" the block by being 2730b05207eSKareemErgawy-TomTom // passed to terminator. If it does, then workList is updated with the 2740b05207eSKareemErgawy-TomTom // corresponding argument to the successor block. 2750b05207eSKareemErgawy-TomTom auto updateWorkListWithSuccessorArguments = 2760b05207eSKareemErgawy-TomTom [&](Value value, BranchOpInterface terminator) { 2770b05207eSKareemErgawy-TomTom if (!terminator) 2780b05207eSKareemErgawy-TomTom return; 2790b05207eSKareemErgawy-TomTom 2800b05207eSKareemErgawy-TomTom for (auto operandIdx : 2810b05207eSKareemErgawy-TomTom llvm::seq<unsigned>(0, terminator->getOperands().size())) { 2820b05207eSKareemErgawy-TomTom Value operand = terminator->getOperand(operandIdx); 2830b05207eSKareemErgawy-TomTom 2840b05207eSKareemErgawy-TomTom if (operand == value) { 2850b05207eSKareemErgawy-TomTom auto succBlockArg = 2860b05207eSKareemErgawy-TomTom terminator.getSuccessorBlockArgument(operandIdx); 2870b05207eSKareemErgawy-TomTom 2880b05207eSKareemErgawy-TomTom if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 2890b05207eSKareemErgawy-TomTom workList.push_back(*succBlockArg); 2900b05207eSKareemErgawy-TomTom } 2910b05207eSKareemErgawy-TomTom } 2920b05207eSKareemErgawy-TomTom }; 2930b05207eSKareemErgawy-TomTom 294aa6eb2afSKareemErgawy-TomTom while (!workList.empty()) { 295aa6eb2afSKareemErgawy-TomTom Value currentItem = workList.pop_back_val(); 296aa6eb2afSKareemErgawy-TomTom 297aa6eb2afSKareemErgawy-TomTom if (!visitedValues.insert(currentItem).second) 298aa6eb2afSKareemErgawy-TomTom continue; 299aa6eb2afSKareemErgawy-TomTom 3000b05207eSKareemErgawy-TomTom // 1 - Look forward: 3010b05207eSKareemErgawy-TomTom // 1.1 - If currentItem escapes to one or more successors, add 3020b05207eSKareemErgawy-TomTom // the corresponding successor arguments to workList. 3030b05207eSKareemErgawy-TomTom updateWorkListWithSuccessorArguments( 3040b05207eSKareemErgawy-TomTom currentItem, dyn_cast<BranchOpInterface>( 3050b05207eSKareemErgawy-TomTom currentItem.getParentBlock()->getTerminator())); 3060b05207eSKareemErgawy-TomTom 3070b05207eSKareemErgawy-TomTom // 1.2 - For each user of currentItem, add the defined values to 3080b05207eSKareemErgawy-TomTom // workList. This way, the user ops can be inspected later if they are 3090b05207eSKareemErgawy-TomTom // detensorable and if so, their operands will be added to workList to 3100b05207eSKareemErgawy-TomTom // potentially discover other parts of the detensorable component. 3110b05207eSKareemErgawy-TomTom for (auto *user : currentItem.getUsers()) 31289d8035eSBenjamin Kramer llvm::append_range(workList, user->getResults()); 3130b05207eSKareemErgawy-TomTom 3140b05207eSKareemErgawy-TomTom // 2 - Look backward: 3150b05207eSKareemErgawy-TomTom // 2.1 - The current item is defined by a block argument. If the owner 3160b05207eSKareemErgawy-TomTom // block is a non-entry one, then: 3170b05207eSKareemErgawy-TomTom // * Add the argument to blockArgsToDetensor. 3180b05207eSKareemErgawy-TomTom // * Walk the use-def chain backwards to add each predecessor's 3190b05207eSKareemErgawy-TomTom // terminator-operands corresponding to currentItem to workList. 3205550c821STres Popp if (dyn_cast<BlockArgument>(currentItem)) { 321aa6eb2afSKareemErgawy-TomTom BlockArgument currentItemBlockArgument = 3225550c821STres Popp cast<BlockArgument>(currentItem); 323aa6eb2afSKareemErgawy-TomTom Block *ownerBlock = currentItemBlockArgument.getOwner(); 324aa6eb2afSKareemErgawy-TomTom 325aa6eb2afSKareemErgawy-TomTom // Function arguments are not detensored/converted. 326aa6eb2afSKareemErgawy-TomTom if (&*ownerBlock->getParent()->begin() == ownerBlock) 327aa6eb2afSKareemErgawy-TomTom continue; 328aa6eb2afSKareemErgawy-TomTom 329aa6eb2afSKareemErgawy-TomTom // This inner-block argument is involved in control-flow, it should be 330aa6eb2afSKareemErgawy-TomTom // detensored. 331aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor.insert(currentItemBlockArgument); 332aa6eb2afSKareemErgawy-TomTom 333aa6eb2afSKareemErgawy-TomTom for (PredecessorIterator pred = ownerBlock->pred_begin(); 334aa6eb2afSKareemErgawy-TomTom pred != ownerBlock->pred_end(); ++pred) { 335bdcf4b9bSKareemErgawy-TomTom BranchOpInterface predTerminator = 336aa6eb2afSKareemErgawy-TomTom dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 337aa6eb2afSKareemErgawy-TomTom 338aa6eb2afSKareemErgawy-TomTom // TODO: For now, we give up if any of the control-flow components 339aa6eb2afSKareemErgawy-TomTom // in a function is not detensorable. Fix that. 340bdcf4b9bSKareemErgawy-TomTom if (!predTerminator) { 341aa6eb2afSKareemErgawy-TomTom opsToDetensor.clear(); 342aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor.clear(); 343aa6eb2afSKareemErgawy-TomTom return; 344aa6eb2afSKareemErgawy-TomTom } 345aa6eb2afSKareemErgawy-TomTom 346aa6eb2afSKareemErgawy-TomTom auto ownerBlockOperands = 347bdcf4b9bSKareemErgawy-TomTom predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); 348aa6eb2afSKareemErgawy-TomTom 3490c789db5SMarkus Böck if (ownerBlockOperands.empty() || 3500c789db5SMarkus Böck ownerBlockOperands.isOperandProduced( 3510c789db5SMarkus Böck currentItemBlockArgument.getArgNumber())) 352aa6eb2afSKareemErgawy-TomTom continue; 353aa6eb2afSKareemErgawy-TomTom 354aa6eb2afSKareemErgawy-TomTom // For each predecessor, add the value it passes to that argument to 355aa6eb2afSKareemErgawy-TomTom // workList to find out how it's computed. 356aa6eb2afSKareemErgawy-TomTom workList.push_back( 3570c789db5SMarkus Böck ownerBlockOperands[currentItemBlockArgument.getArgNumber()]); 358aa6eb2afSKareemErgawy-TomTom } 359aa6eb2afSKareemErgawy-TomTom 360aa6eb2afSKareemErgawy-TomTom continue; 361aa6eb2afSKareemErgawy-TomTom } 362aa6eb2afSKareemErgawy-TomTom 363aa6eb2afSKareemErgawy-TomTom Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 364aa6eb2afSKareemErgawy-TomTom 365aa6eb2afSKareemErgawy-TomTom if (!visitedOps.insert(currentItemDefiningOp).second) 366aa6eb2afSKareemErgawy-TomTom continue; 367aa6eb2afSKareemErgawy-TomTom 3680b05207eSKareemErgawy-TomTom // 2.2 - The current item is computed by a GenericOp. If the op should 3690b05207eSKareemErgawy-TomTom // be detensored, then: 3700b05207eSKareemErgawy-TomTom // * Add it to opsToDetensor. 3710b05207eSKareemErgawy-TomTom // * Add its operands to workList to discover other parts of the 3720b05207eSKareemErgawy-TomTom // potentially detensorable component. 373aa6eb2afSKareemErgawy-TomTom if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 374aa6eb2afSKareemErgawy-TomTom // The op was encountered already, no need to inspect it again. 375aa6eb2afSKareemErgawy-TomTom if (opsToDetensor.count(genericOp)) 376aa6eb2afSKareemErgawy-TomTom continue; 377aa6eb2afSKareemErgawy-TomTom 378bdcf4b9bSKareemErgawy-TomTom // The op should not be detensored, give up on it but continue with 379bdcf4b9bSKareemErgawy-TomTom // discovering the rest of the control-flow component. 380aa6eb2afSKareemErgawy-TomTom if (!shouldBeDetensored(genericOp, typeConverter)) { 381bdcf4b9bSKareemErgawy-TomTom continue; 382aa6eb2afSKareemErgawy-TomTom } 383aa6eb2afSKareemErgawy-TomTom 384aa6eb2afSKareemErgawy-TomTom opsToDetensor.insert(genericOp); 385d3b3f765SJacques Pienaar llvm::append_range(workList, genericOp.getInputs()); 386aa6eb2afSKareemErgawy-TomTom continue; 387aa6eb2afSKareemErgawy-TomTom } 388aa6eb2afSKareemErgawy-TomTom 3890b05207eSKareemErgawy-TomTom // 2.3 - The current item is the result of a FromElementsOp, it will be 390aa6eb2afSKareemErgawy-TomTom // trivially detensored later as part of canonicalization patterns 391aa6eb2afSKareemErgawy-TomTom // applied at the end of detensoring. 392aa6eb2afSKareemErgawy-TomTom // 393aa6eb2afSKareemErgawy-TomTom // Note: No need to check whether the result type of this op is 394aa6eb2afSKareemErgawy-TomTom // detensorable since if it wasn't we wouldn't reach that point in the 395aa6eb2afSKareemErgawy-TomTom // work list. 396e4c39501SRahul Joshi if (isa<tensor::FromElementsOp>(currentItemDefiningOp)) 397aa6eb2afSKareemErgawy-TomTom continue; 398aa6eb2afSKareemErgawy-TomTom 3990b05207eSKareemErgawy-TomTom // 2.4 - The current item is the result of a scalar op, add all its 4000b05207eSKareemErgawy-TomTom // operands to the work list. 401aa6eb2afSKareemErgawy-TomTom if (llvm::all_of( 402aa6eb2afSKareemErgawy-TomTom currentItemDefiningOp->getResultTypes(), 403aa6eb2afSKareemErgawy-TomTom [&](Type resultType) { return resultType.isIntOrFloat(); })) 40489d8035eSBenjamin Kramer llvm::append_range(workList, currentItemDefiningOp->getOperands()); 405aa6eb2afSKareemErgawy-TomTom } 406bdcf4b9bSKareemErgawy-TomTom 407bdcf4b9bSKareemErgawy-TomTom // Since the cost model gives up on some ops (see the details of step 2.2 408bdcf4b9bSKareemErgawy-TomTom // above), block arguments that correspond to the values produced by those 409bdcf4b9bSKareemErgawy-TomTom // ops should not be detensored as well. 410bdcf4b9bSKareemErgawy-TomTom 411bdcf4b9bSKareemErgawy-TomTom DenseSet<BlockArgument> blockArgsToRemove; 412bdcf4b9bSKareemErgawy-TomTom 413bdcf4b9bSKareemErgawy-TomTom for (auto &blockArg : blockArgsToDetensor) { 414bdcf4b9bSKareemErgawy-TomTom Block *block = blockArg.getParentBlock(); 415bdcf4b9bSKareemErgawy-TomTom 416bdcf4b9bSKareemErgawy-TomTom // For the potentially detensorable block argument, find the 417bdcf4b9bSKareemErgawy-TomTom // correpsonding operands in predecessor blocks. 418bdcf4b9bSKareemErgawy-TomTom for (PredecessorIterator pred = block->pred_begin(); 419bdcf4b9bSKareemErgawy-TomTom pred != block->pred_end(); ++pred) { 420bdcf4b9bSKareemErgawy-TomTom BranchOpInterface terminator = 421bdcf4b9bSKareemErgawy-TomTom dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 422bdcf4b9bSKareemErgawy-TomTom auto blockOperands = 423bdcf4b9bSKareemErgawy-TomTom terminator.getSuccessorOperands(pred.getSuccessorIndex()); 424bdcf4b9bSKareemErgawy-TomTom 4250c789db5SMarkus Böck if (blockOperands.empty() || 4260c789db5SMarkus Böck blockOperands.isOperandProduced(blockArg.getArgNumber())) 427bdcf4b9bSKareemErgawy-TomTom continue; 428bdcf4b9bSKareemErgawy-TomTom 429bdcf4b9bSKareemErgawy-TomTom Operation *definingOp = 4300c789db5SMarkus Böck blockOperands[blockArg.getArgNumber()].getDefiningOp(); 431bdcf4b9bSKareemErgawy-TomTom 432bdcf4b9bSKareemErgawy-TomTom // If the operand is defined by a GenericOp that will not be 433bdcf4b9bSKareemErgawy-TomTom // detensored, then do not detensor the corresponding block argument. 4340c789db5SMarkus Böck if (isa_and_nonnull<GenericOp>(definingOp) && 435bdcf4b9bSKareemErgawy-TomTom opsToDetensor.count(definingOp) == 0) { 436bdcf4b9bSKareemErgawy-TomTom blockArgsToRemove.insert(blockArg); 437bdcf4b9bSKareemErgawy-TomTom break; 438bdcf4b9bSKareemErgawy-TomTom } 439bdcf4b9bSKareemErgawy-TomTom } 440bdcf4b9bSKareemErgawy-TomTom } 441bdcf4b9bSKareemErgawy-TomTom 442bdcf4b9bSKareemErgawy-TomTom for (auto &blockArg : blockArgsToRemove) { 443bdcf4b9bSKareemErgawy-TomTom blockArgsToDetensor.erase(blockArg); 444bdcf4b9bSKareemErgawy-TomTom } 445aa6eb2afSKareemErgawy-TomTom } 446aa6eb2afSKareemErgawy-TomTom }; 447aa6eb2afSKareemErgawy-TomTom 448aa6eb2afSKareemErgawy-TomTom /// Detensorize everything that can detensored. 449aa6eb2afSKareemErgawy-TomTom class AggressiveDetensoringModel : public CostModel { 450aa6eb2afSKareemErgawy-TomTom public: 4517ceffae1SRiver Riddle void compute(FunctionOpInterface func, 4527ceffae1SRiver Riddle DetensorizeTypeConverter typeConverter, 453aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> &opsToDetensor, 454aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> &blockArgsToDetensor) override { 455c10995a8SStella Laurenzo func->walk([&](GenericOp genericOp) { 456aa6eb2afSKareemErgawy-TomTom if (shouldBeDetensored(genericOp, typeConverter)) 457aa6eb2afSKareemErgawy-TomTom opsToDetensor.insert(genericOp); 458aa6eb2afSKareemErgawy-TomTom }); 459aa6eb2afSKareemErgawy-TomTom 460ecba7c58SRiver Riddle for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1)) 461aa6eb2afSKareemErgawy-TomTom for (BlockArgument blockArgument : block.getArguments()) 462aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor.insert(blockArgument); 463aa6eb2afSKareemErgawy-TomTom } 464aa6eb2afSKareemErgawy-TomTom }; 465aa6eb2afSKareemErgawy-TomTom 466c10995a8SStella Laurenzo void runOnOperation() override { 467aa6eb2afSKareemErgawy-TomTom MLIRContext *context = &getContext(); 46867e0d58dSKareemErgawy-TomTom DetensorizeTypeConverter typeConverter; 469dc4e913bSChris Lattner RewritePatternSet patterns(context); 47067e0d58dSKareemErgawy-TomTom ConversionTarget target(*context); 471aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> opsToDetensor; 472aa6eb2afSKareemErgawy-TomTom DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 473aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> blockArgsToDetensor; 474b6d9e30bSFelix Schneider FunctionOpInterface funcOp = getOperation(); 47567e0d58dSKareemErgawy-TomTom 4760fac44d8SOkwan Kwon if (funcOp.getFunctionBody().empty()) 4770fac44d8SOkwan Kwon return; 4780fac44d8SOkwan Kwon 47965eedcebSAlex Zinenko // Make sure the entry block of the function doesn't contain any Linalg ops. 48065eedcebSAlex Zinenko // Otherwise, it may lead to the signature of the block being changed by the 48165eedcebSAlex Zinenko // dialect conversion below, which would make the function op invalid 48265eedcebSAlex Zinenko // because its type shouldn't change. 48365eedcebSAlex Zinenko IRRewriter rewriter(funcOp->getContext()); 48465eedcebSAlex Zinenko Block *entryBlock = &funcOp.getFunctionBody().front(); 48565eedcebSAlex Zinenko Block *postEntryBlock = 48665eedcebSAlex Zinenko rewriter.splitBlock(entryBlock, entryBlock->begin()); 48765eedcebSAlex Zinenko rewriter.setInsertionPointToStart(entryBlock); 48865eedcebSAlex Zinenko auto branch = 48965eedcebSAlex Zinenko rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock); 49065eedcebSAlex Zinenko 491aa6eb2afSKareemErgawy-TomTom if (aggressiveMode.getValue()) { 492aa6eb2afSKareemErgawy-TomTom AggressiveDetensoringModel costModel; 4937ceffae1SRiver Riddle costModel.compute(funcOp, typeConverter, opsToDetensor, 494aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor); 495aa6eb2afSKareemErgawy-TomTom } else { 496bdcf4b9bSKareemErgawy-TomTom ControlFlowDetectionModel costModel; 4977ceffae1SRiver Riddle costModel.compute(funcOp, typeConverter, opsToDetensor, 498aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor); 499aa6eb2afSKareemErgawy-TomTom } 500aa6eb2afSKareemErgawy-TomTom 501aa6eb2afSKareemErgawy-TomTom detensorableBranchOps = 502aa6eb2afSKareemErgawy-TomTom CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 503aa6eb2afSKareemErgawy-TomTom 504aa6eb2afSKareemErgawy-TomTom target.addDynamicallyLegalOp<GenericOp>( 505aa6eb2afSKareemErgawy-TomTom [&](GenericOp op) { return !opsToDetensor.count(op); }); 50667e0d58dSKareemErgawy-TomTom 507c10995a8SStella Laurenzo target.markUnknownOpDynamicallyLegal([&](Operation *op) { 508aa6eb2afSKareemErgawy-TomTom // A function is legal if all of its non-entry blocks are legal. We 5090b05207eSKareemErgawy-TomTom // don't legalize the entry block (i.e. the function's signature) 5100b05207eSKareemErgawy-TomTom // since detensoring can't happen along external calling convention 511aa6eb2afSKareemErgawy-TomTom // boundaries, which we conservatively approximate as all function 512aa6eb2afSKareemErgawy-TomTom // signatures. 5137ceffae1SRiver Riddle if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { 514ecba7c58SRiver Riddle Region &body = funcOp.getFunctionBody(); 515c10995a8SStella Laurenzo return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { 5166786d7e4SMehdi Amini return !llvm::any_of( 517c10995a8SStella Laurenzo blockArgsToDetensor, [&](BlockArgument blockArgument) { 518aa6eb2afSKareemErgawy-TomTom return blockArgument.getOwner() == &block && 519aa6eb2afSKareemErgawy-TomTom !typeConverter.isLegal(blockArgument.getType()); 5206786d7e4SMehdi Amini }); 5213b021fbdSKareemErgawy-TomTom }); 522c10995a8SStella Laurenzo } 5233b021fbdSKareemErgawy-TomTom 524aa6eb2afSKareemErgawy-TomTom if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 525aa6eb2afSKareemErgawy-TomTom isLegalForReturnOpTypeConversionPattern(op, typeConverter, 526aa6eb2afSKareemErgawy-TomTom /*returnOpAlwaysLegal*/ true)) 527aa6eb2afSKareemErgawy-TomTom return true; 528aa6eb2afSKareemErgawy-TomTom 529aa6eb2afSKareemErgawy-TomTom if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 530aa6eb2afSKareemErgawy-TomTom if (!detensorableBranchOps.count(branchOp)) 531aa6eb2afSKareemErgawy-TomTom return true; 532aa6eb2afSKareemErgawy-TomTom 533aa6eb2afSKareemErgawy-TomTom for (auto operandIdx : detensorableBranchOps[branchOp]) 534aa6eb2afSKareemErgawy-TomTom if (!typeConverter.isLegal( 535aa6eb2afSKareemErgawy-TomTom branchOp->getOperand(operandIdx).getType())) 536aa6eb2afSKareemErgawy-TomTom return false; 537aa6eb2afSKareemErgawy-TomTom 538aa6eb2afSKareemErgawy-TomTom return true; 539aa6eb2afSKareemErgawy-TomTom } 540aa6eb2afSKareemErgawy-TomTom 541aa6eb2afSKareemErgawy-TomTom return false; 5423b021fbdSKareemErgawy-TomTom }); 5433b021fbdSKareemErgawy-TomTom 544b4e0507cSTres Popp patterns.add<DetensorizeGenericOp>(typeConverter, context); 545b4e0507cSTres Popp patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter, 546aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor); 547aa6eb2afSKareemErgawy-TomTom // Since non-entry block arguments get detensorized, we also need to 548aa6eb2afSKareemErgawy-TomTom // update the control flow inside the function to reflect the correct 549aa6eb2afSKareemErgawy-TomTom // types. 550aa6eb2afSKareemErgawy-TomTom auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 551aa6eb2afSKareemErgawy-TomTom int operandIdx) -> bool { 552aa6eb2afSKareemErgawy-TomTom return detensorableBranchOps.count(branchOp) && 553aa6eb2afSKareemErgawy-TomTom detensorableBranchOps[branchOp].count(operandIdx); 554aa6eb2afSKareemErgawy-TomTom }; 555aa6eb2afSKareemErgawy-TomTom 556aa6eb2afSKareemErgawy-TomTom populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 557aa6eb2afSKareemErgawy-TomTom shouldConvertBranchOperand); 55867e0d58dSKareemErgawy-TomTom 559c10995a8SStella Laurenzo if (failed( 560c10995a8SStella Laurenzo applyFullConversion(getOperation(), target, std::move(patterns)))) 56167e0d58dSKareemErgawy-TomTom signalPassFailure(); 56267e0d58dSKareemErgawy-TomTom 563dc4e913bSChris Lattner RewritePatternSet canonPatterns(context); 564550ea385SAlexander Belyaev tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context); 565*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns)))) 56667e0d58dSKareemErgawy-TomTom signalPassFailure(); 56765eedcebSAlex Zinenko 56865eedcebSAlex Zinenko // Get rid of the dummy entry block we created in the beginning to work 56965eedcebSAlex Zinenko // around dialect conversion signature rewriting. 57065eedcebSAlex Zinenko rewriter.eraseOp(branch); 57165eedcebSAlex Zinenko rewriter.mergeBlocks(postEntryBlock, entryBlock); 57267e0d58dSKareemErgawy-TomTom } 57367e0d58dSKareemErgawy-TomTom }; 57467e0d58dSKareemErgawy-TomTom } // namespace 575