//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

/// Source materialization callback for the detensorize type converter.
///
/// Re-creates a tensor from a detensored (scalar) value by wrapping it in a
/// `tensor.from_elements` producing a 0-D tensor. Returns a null Value when
/// the single input is already a tensor, signalling the conversion framework
/// that no materialization is needed here.
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  auto inputType = inputs[0].getType();
  if (isa<TensorType>(inputType))
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  return builder.create<tensor::FromElementsOp>(
      loc, RankedTensorType::get({}, inputType), inputs[0]);
}

namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
51 bool canBeDetensored(TensorType tensorType) { 52 return tensorType.hasRank() && tensorType.getRank() == 0; 53 } 54 55 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { 56 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); 57 return genericOp && 58 llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) { 59 return !typeConverter.isLegal(opOperand.get().getType()); 60 }); 61 } 62 63 /// A conversion pattern for detensoring `linalg.generic` ops. 64 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { 65 public: 66 using OpConversionPattern::OpConversionPattern; 67 LogicalResult 68 matchAndRewrite(GenericOp op, OpAdaptor adaptor, 69 ConversionPatternRewriter &rewriter) const override { 70 Block *originalBlock = op->getBlock(); 71 72 // Gather some information about the op before inlining its region. 73 Block *opEntryBlock = &*op.getRegion().begin(); 74 YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator()); 75 76 // Split the op's region before the op. This way, we have a clear insertion 77 // point in which the op can be inlined. 78 Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); 79 rewriter.inlineRegionBefore(op.getRegion(), newBlock); 80 // Now that op's region is inlined, the operands of its YieldOp are mapped 81 // to the materialized target values. Therefore, we can replace the op's 82 // uses with those of its YielOp's operands. 83 rewriter.replaceOp(op, yieldOp->getOperands()); 84 85 // No need for these intermediate blocks, merge them into 1. 86 rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); 87 rewriter.mergeBlocks(newBlock, originalBlock, {}); 88 89 rewriter.eraseOp(&*Block::iterator(yieldOp)); 90 91 return success(); 92 } 93 }; 94 95 /// A conversion pattern for detensoring internal (non-entry) blocks within a 96 /// function. 
97 struct FunctionNonEntryBlockConversion 98 : public OpInterfaceConversionPattern<FunctionOpInterface> { 99 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, 100 DenseSet<BlockArgument> blockArgsToDetensor) 101 : OpInterfaceConversionPattern(converter, ctx), 102 blockArgsToDetensor(std::move(blockArgsToDetensor)) {} 103 104 LogicalResult 105 matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands, 106 ConversionPatternRewriter &rewriter) const override { 107 rewriter.startOpModification(op); 108 Region ®ion = op.getFunctionBody(); 109 110 for (Block &block : 111 llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { 112 TypeConverter::SignatureConversion conversion( 113 /*numOrigInputs=*/block.getNumArguments()); 114 115 for (BlockArgument blockArgument : block.getArguments()) { 116 int idx = blockArgument.getArgNumber(); 117 118 if (blockArgsToDetensor.count(blockArgument)) 119 conversion.addInputs(idx, {getTypeConverter()->convertType( 120 block.getArgumentTypes()[idx])}); 121 else 122 conversion.addInputs(idx, {block.getArgumentTypes()[idx]}); 123 } 124 125 rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); 126 } 127 128 rewriter.finalizeOpModification(op); 129 return success(); 130 } 131 132 private: 133 const DenseSet<BlockArgument> blockArgsToDetensor; 134 }; 135 136 class DetensorizeTypeConverter : public TypeConverter { 137 public: 138 DetensorizeTypeConverter() { 139 addConversion([](Type type) { return type; }); 140 141 // A TensorType that can be detensored, is converted to the underlying 142 // element type. 143 addConversion([](TensorType tensorType) -> Type { 144 if (canBeDetensored(tensorType)) 145 return tensorType.getElementType(); 146 147 return tensorType; 148 }); 149 150 // A tensor value is detensoried by extracting its element(s). 
151 addTargetMaterialization([](OpBuilder &builder, Type type, 152 ValueRange inputs, Location loc) -> Value { 153 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); 154 }); 155 156 addSourceMaterialization(sourceMaterializationCallback); 157 } 158 }; 159 160 /// @see LinalgDetensorize in Linalg/Passes.td for more details. 161 struct LinalgDetensorize 162 : public impl::LinalgDetensorizePassBase<LinalgDetensorize> { 163 using impl::LinalgDetensorizePassBase< 164 LinalgDetensorize>::LinalgDetensorizePassBase; 165 LinalgDetensorize() = default; 166 167 class CostModel { 168 public: 169 virtual ~CostModel() = default; 170 171 /// A cost model algorithm computes the following outputs: 172 /// 173 /// - opsToDetensor: the list of linalg ops that should be 174 /// detensored. 175 /// 176 /// - blockArgsToDetensor: since the operands and results of detensored 177 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come 178 /// from a BB argument and a linalg op's output can be passed to successor 179 /// BBs), we need to maintain the sub-set of arguments that should be 180 /// detensored (i.e. converted by typeConverter) for each affected BB. 181 /// 182 /// Example: 183 /// 184 /// For the following snippet: 185 /// ... 186 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): 187 /// %7 = tensor.empty() : tensor<i32> 188 /// %8 = linalg.generic #attrs 189 /// ins(%6, %6 : tensor<i32>, tensor<i32>) 190 /// outs(%7 : tensor<i32>) { 191 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): 192 /// %9 = arith.addi %arg0, %arg1 : i32 193 /// linalg.yield %9 : i32 194 /// } -> tensor<i32> 195 /// %10 = "some.op"(%9) 196 /// br ^bb2(%8 : tensor<i32>) 197 /// ... 198 /// 199 /// if the cost model decides that the linalg.generic op should be 200 /// detensored, then: 201 /// - opsToDetensor should be = {linalg.generic{add}}. 202 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. 
203 virtual void compute(FunctionOpInterface func, 204 DetensorizeTypeConverter typeConverter, 205 DenseSet<Operation *> &opsToDetensor, 206 DenseSet<BlockArgument> &blockArgsToDetensor) = 0; 207 208 /// From the blockArgsToDetensor set computed by a CostModel 209 /// implementation, this method computes the corresponding branch op 210 /// detensoring. The result is a map from a branch op to a subset of indices 211 /// of its operands. The indices specify which of the branch op's operands 212 /// should be detensored. 213 /// 214 /// For the previous example, this method would compute: {bb2 -> {0}}. 215 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( 216 const DenseSet<BlockArgument> &blockArgsToDetensor) { 217 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 218 219 for (auto blockArgumentElem : blockArgsToDetensor) { 220 Block *block = blockArgumentElem.getOwner(); 221 222 for (PredecessorIterator pred = block->pred_begin(); 223 pred != block->pred_end(); ++pred) { 224 BranchOpInterface terminator = 225 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 226 auto blockOperands = 227 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 228 229 if (blockOperands.empty() || 230 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber())) 231 continue; 232 233 detensorableBranchOps[terminator].insert( 234 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber())); 235 } 236 } 237 238 return detensorableBranchOps; 239 } 240 }; 241 242 /// Detensorize linalg ops involved in control-flow within a function. 243 /// 244 /// This model starts from BranchOps and CondBranchOps within a function. For 245 /// each such branch, the model then walks the use-def chain for the branch's 246 /// condition backwards in order to understand where the condition's value 247 /// comes from. 
If the condition value is (indirectly) computed by a linalg op 248 /// that can be detensored, the model then continues walking the use-def chain 249 /// in order to understand where the linalg op's operands come from. This 250 /// leads to discovering a "detensoring component". A detensoring component is 251 /// the set of operations + block arguments that are involved in control-flow 252 /// AND can be detensored. 253 class ControlFlowDetectionModel : public CostModel { 254 public: 255 void compute(FunctionOpInterface func, 256 DetensorizeTypeConverter typeConverter, 257 DenseSet<Operation *> &opsToDetensor, 258 DenseSet<BlockArgument> &blockArgsToDetensor) override { 259 SmallVector<Value> workList; 260 261 func->walk([&](cf::CondBranchOp condBr) { 262 llvm::append_range(workList, condBr.getOperands()); 263 }); 264 265 func->walk([&](cf::BranchOp br) { 266 llvm::append_range(workList, br.getOperands()); 267 }); 268 269 DenseSet<Value> visitedValues; 270 DenseSet<Operation *> visitedOps; 271 272 // For a (to-be-detesored) value, check if it "escapes" the block by being 273 // passed to terminator. If it does, then workList is updated with the 274 // corresponding argument to the successor block. 
275 auto updateWorkListWithSuccessorArguments = 276 [&](Value value, BranchOpInterface terminator) { 277 if (!terminator) 278 return; 279 280 for (auto operandIdx : 281 llvm::seq<unsigned>(0, terminator->getOperands().size())) { 282 Value operand = terminator->getOperand(operandIdx); 283 284 if (operand == value) { 285 auto succBlockArg = 286 terminator.getSuccessorBlockArgument(operandIdx); 287 288 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) 289 workList.push_back(*succBlockArg); 290 } 291 } 292 }; 293 294 while (!workList.empty()) { 295 Value currentItem = workList.pop_back_val(); 296 297 if (!visitedValues.insert(currentItem).second) 298 continue; 299 300 // 1 - Look forward: 301 // 1.1 - If currentItem escapes to one or more successors, add 302 // the corresponding successor arguments to workList. 303 updateWorkListWithSuccessorArguments( 304 currentItem, dyn_cast<BranchOpInterface>( 305 currentItem.getParentBlock()->getTerminator())); 306 307 // 1.2 - For each user of currentItem, add the defined values to 308 // workList. This way, the user ops can be inspected later if they are 309 // detensorable and if so, their operands will be added to workList to 310 // potentially discover other parts of the detensorable component. 311 for (auto *user : currentItem.getUsers()) 312 llvm::append_range(workList, user->getResults()); 313 314 // 2 - Look backward: 315 // 2.1 - The current item is defined by a block argument. If the owner 316 // block is a non-entry one, then: 317 // * Add the argument to blockArgsToDetensor. 318 // * Walk the use-def chain backwards to add each predecessor's 319 // terminator-operands corresponding to currentItem to workList. 320 if (dyn_cast<BlockArgument>(currentItem)) { 321 BlockArgument currentItemBlockArgument = 322 cast<BlockArgument>(currentItem); 323 Block *ownerBlock = currentItemBlockArgument.getOwner(); 324 325 // Function arguments are not detensored/converted. 
326 if (&*ownerBlock->getParent()->begin() == ownerBlock) 327 continue; 328 329 // This inner-block argument is involved in control-flow, it should be 330 // detensored. 331 blockArgsToDetensor.insert(currentItemBlockArgument); 332 333 for (PredecessorIterator pred = ownerBlock->pred_begin(); 334 pred != ownerBlock->pred_end(); ++pred) { 335 BranchOpInterface predTerminator = 336 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 337 338 // TODO: For now, we give up if any of the control-flow components 339 // in a function is not detensorable. Fix that. 340 if (!predTerminator) { 341 opsToDetensor.clear(); 342 blockArgsToDetensor.clear(); 343 return; 344 } 345 346 auto ownerBlockOperands = 347 predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); 348 349 if (ownerBlockOperands.empty() || 350 ownerBlockOperands.isOperandProduced( 351 currentItemBlockArgument.getArgNumber())) 352 continue; 353 354 // For each predecessor, add the value it passes to that argument to 355 // workList to find out how it's computed. 356 workList.push_back( 357 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]); 358 } 359 360 continue; 361 } 362 363 Operation *currentItemDefiningOp = currentItem.getDefiningOp(); 364 365 if (!visitedOps.insert(currentItemDefiningOp).second) 366 continue; 367 368 // 2.2 - The current item is computed by a GenericOp. If the op should 369 // be detensored, then: 370 // * Add it to opsToDetensor. 371 // * Add its operands to workList to discover other parts of the 372 // potentially detensorable component. 373 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { 374 // The op was encountered already, no need to inspect it again. 375 if (opsToDetensor.count(genericOp)) 376 continue; 377 378 // The op should not be detensored, give up on it but continue with 379 // discovering the rest of the control-flow component. 
380 if (!shouldBeDetensored(genericOp, typeConverter)) { 381 continue; 382 } 383 384 opsToDetensor.insert(genericOp); 385 llvm::append_range(workList, genericOp.getInputs()); 386 continue; 387 } 388 389 // 2.3 - The current item is the result of a FromElementsOp, it will be 390 // trivially detensored later as part of canonicalization patterns 391 // applied at the end of detensoring. 392 // 393 // Note: No need to check whether the result type of this op is 394 // detensorable since if it wasn't we wouldn't reach that point in the 395 // work list. 396 if (isa<tensor::FromElementsOp>(currentItemDefiningOp)) 397 continue; 398 399 // 2.4 - The current item is the result of a scalar op, add all its 400 // operands to the work list. 401 if (llvm::all_of( 402 currentItemDefiningOp->getResultTypes(), 403 [&](Type resultType) { return resultType.isIntOrFloat(); })) 404 llvm::append_range(workList, currentItemDefiningOp->getOperands()); 405 } 406 407 // Since the cost model gives up on some ops (see the details of step 2.2 408 // above), block arguments that correspond to the values produced by those 409 // ops should not be detensored as well. 410 411 DenseSet<BlockArgument> blockArgsToRemove; 412 413 for (auto &blockArg : blockArgsToDetensor) { 414 Block *block = blockArg.getParentBlock(); 415 416 // For the potentially detensorable block argument, find the 417 // correpsonding operands in predecessor blocks. 
418 for (PredecessorIterator pred = block->pred_begin(); 419 pred != block->pred_end(); ++pred) { 420 BranchOpInterface terminator = 421 dyn_cast<BranchOpInterface>((*pred)->getTerminator()); 422 auto blockOperands = 423 terminator.getSuccessorOperands(pred.getSuccessorIndex()); 424 425 if (blockOperands.empty() || 426 blockOperands.isOperandProduced(blockArg.getArgNumber())) 427 continue; 428 429 Operation *definingOp = 430 blockOperands[blockArg.getArgNumber()].getDefiningOp(); 431 432 // If the operand is defined by a GenericOp that will not be 433 // detensored, then do not detensor the corresponding block argument. 434 if (isa_and_nonnull<GenericOp>(definingOp) && 435 opsToDetensor.count(definingOp) == 0) { 436 blockArgsToRemove.insert(blockArg); 437 break; 438 } 439 } 440 } 441 442 for (auto &blockArg : blockArgsToRemove) { 443 blockArgsToDetensor.erase(blockArg); 444 } 445 } 446 }; 447 448 /// Detensorize everything that can detensored. 449 class AggressiveDetensoringModel : public CostModel { 450 public: 451 void compute(FunctionOpInterface func, 452 DetensorizeTypeConverter typeConverter, 453 DenseSet<Operation *> &opsToDetensor, 454 DenseSet<BlockArgument> &blockArgsToDetensor) override { 455 func->walk([&](GenericOp genericOp) { 456 if (shouldBeDetensored(genericOp, typeConverter)) 457 opsToDetensor.insert(genericOp); 458 }); 459 460 for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1)) 461 for (BlockArgument blockArgument : block.getArguments()) 462 blockArgsToDetensor.insert(blockArgument); 463 } 464 }; 465 466 void runOnOperation() override { 467 MLIRContext *context = &getContext(); 468 DetensorizeTypeConverter typeConverter; 469 RewritePatternSet patterns(context); 470 ConversionTarget target(*context); 471 DenseSet<Operation *> opsToDetensor; 472 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; 473 DenseSet<BlockArgument> blockArgsToDetensor; 474 FunctionOpInterface funcOp = getOperation(); 475 476 if 
(funcOp.getFunctionBody().empty()) 477 return; 478 479 // Make sure the entry block of the function doesn't contain any Linalg ops. 480 // Otherwise, it may lead to the signature of the block being changed by the 481 // dialect conversion below, which would make the function op invalid 482 // because its type shouldn't change. 483 IRRewriter rewriter(funcOp->getContext()); 484 Block *entryBlock = &funcOp.getFunctionBody().front(); 485 Block *postEntryBlock = 486 rewriter.splitBlock(entryBlock, entryBlock->begin()); 487 rewriter.setInsertionPointToStart(entryBlock); 488 auto branch = 489 rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock); 490 491 if (aggressiveMode.getValue()) { 492 AggressiveDetensoringModel costModel; 493 costModel.compute(funcOp, typeConverter, opsToDetensor, 494 blockArgsToDetensor); 495 } else { 496 ControlFlowDetectionModel costModel; 497 costModel.compute(funcOp, typeConverter, opsToDetensor, 498 blockArgsToDetensor); 499 } 500 501 detensorableBranchOps = 502 CostModel::computeBranchOpDetensoring(blockArgsToDetensor); 503 504 target.addDynamicallyLegalOp<GenericOp>( 505 [&](GenericOp op) { return !opsToDetensor.count(op); }); 506 507 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 508 // A function is legal if all of its non-entry blocks are legal. We 509 // don't legalize the entry block (i.e. the function's signature) 510 // since detensoring can't happen along external calling convention 511 // boundaries, which we conservatively approximate as all function 512 // signatures. 
513 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { 514 Region &body = funcOp.getFunctionBody(); 515 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { 516 return !llvm::any_of( 517 blockArgsToDetensor, [&](BlockArgument blockArgument) { 518 return blockArgument.getOwner() == &block && 519 !typeConverter.isLegal(blockArgument.getType()); 520 }); 521 }); 522 } 523 524 if (isNotBranchOpInterfaceOrReturnLikeOp(op) || 525 isLegalForReturnOpTypeConversionPattern(op, typeConverter, 526 /*returnOpAlwaysLegal*/ true)) 527 return true; 528 529 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { 530 if (!detensorableBranchOps.count(branchOp)) 531 return true; 532 533 for (auto operandIdx : detensorableBranchOps[branchOp]) 534 if (!typeConverter.isLegal( 535 branchOp->getOperand(operandIdx).getType())) 536 return false; 537 538 return true; 539 } 540 541 return false; 542 }); 543 544 patterns.add<DetensorizeGenericOp>(typeConverter, context); 545 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter, 546 blockArgsToDetensor); 547 // Since non-entry block arguments get detensorized, we also need to 548 // update the control flow inside the function to reflect the correct 549 // types. 
550 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, 551 int operandIdx) -> bool { 552 return detensorableBranchOps.count(branchOp) && 553 detensorableBranchOps[branchOp].count(operandIdx); 554 }; 555 556 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, 557 shouldConvertBranchOperand); 558 559 if (failed( 560 applyFullConversion(getOperation(), target, std::move(patterns)))) 561 signalPassFailure(); 562 563 RewritePatternSet canonPatterns(context); 564 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context); 565 if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns)))) 566 signalPassFailure(); 567 568 // Get rid of the dummy entry block we created in the beginning to work 569 // around dialect conversion signature rewriting. 570 rewriter.eraseOp(branch); 571 rewriter.mergeBlocks(postEntryBlock, entryBlock); 572 } 573 }; 574 } // namespace 575