1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" 10 11 #include "mlir/AsmParser/AsmParser.h" 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Arith/Utils/Utils.h" 16 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 17 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 18 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h" 21 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" 22 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 23 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 24 #include "mlir/Dialect/Linalg/Utils/Utils.h" 25 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 26 #include "mlir/Dialect/Tensor/IR/Tensor.h" 27 #include "mlir/Dialect/Tensor/Utils/Utils.h" 28 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 29 #include "mlir/Dialect/Transform/IR/TransformOps.h" 30 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 31 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 32 #include "mlir/Dialect/Transform/Utils/Utils.h" 33 #include "mlir/Dialect/Utils/IndexingUtils.h" 34 #include "mlir/Dialect/Utils/StaticValueUtils.h" 35 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 36 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 37 #include "mlir/IR/BuiltinTypeInterfaces.h" 38 #include "mlir/IR/PatternMatch.h" 39 #include "mlir/IR/TypeUtilities.h" 40 #include "mlir/Interfaces/TilingInterface.h" 41 #include "mlir/Support/LLVM.h" 42 #include "mlir/Support/TypeID.h" 43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 44 #include "llvm/ADT/STLExtras.h" 45 #include "llvm/ADT/ScopeExit.h" 46 #include "llvm/ADT/TypeSwitch.h" 47 #include "llvm/Support/Debug.h" 48 #include <type_traits> 49 50 using namespace mlir; 51 using namespace mlir::linalg; 52 using namespace mlir::transform; 53 54 #define DEBUG_TYPE "linalg-transforms" 55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 56 #define DBGSNL() (llvm::dbgs() << "\n") 57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") 58 59 /// Attempts to apply the pattern specified as template argument to the given 60 /// operation. The pattern is expected to have a `returningMatchAndRewrite` 61 /// function that returns the "main" result or failure. Returns failure if the 62 /// pattern failed to apply. Extra arguments are forwarded to the pattern 63 /// constructor. 64 template <typename PatternTy, typename... Args> 65 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { 66 // Check if the given operation has the type expected by the pattern. 67 using OpTy = typename llvm::function_traits< 68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; 69 auto op = dyn_cast<OpTy>(operation); 70 if (!op) 71 return failure(); 72 73 // Apply the pattern directly to the op. 
  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
  // We want to discourage direct use of PatternRewriter in APIs, but in this
  // very specific case, an IRRewriter is not enough.
  struct TrivialPatternRewriter : public PatternRewriter {
  public:
    explicit TrivialPatternRewriter(MLIRContext *context)
        : PatternRewriter(context) {}
  };
  TrivialPatternRewriter rewriter(operation->getContext());
  rewriter.setInsertionPoint(operation);
  auto result = pattern.returningMatchAndRewrite(op, rewriter);
  if (failed(result))
    return failure();
  return cast<LinalgOp>(result->getOperation());
}

/// For each `ofr` in `ofrs` that is an index attr, a param of index type, or a
/// transform dialect handle mapped to exactly one op with one index result,
/// append the corresponding attribute or value to `result`.
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
  for (OpFoldResult ofr : ofrs) {
    if (auto attr = dyn_cast<Attribute>(ofr)) {
      if (!isa<IntegerAttr>(attr))
        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
      result.push_back(ofr);
      continue;
    }

    Value transformValue = cast<Value>(ofr);
    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
      ArrayRef<Attribute> params = state.getParams(transformValue);
      if (params.size() != 1)
        return transformOp.emitDefiniteFailure()
               << "requires exactly one parameter associated";
      result.push_back(params[0]);
      continue;
    }

    auto payloadOps = state.getPayloadOps(transformValue);
    if (!llvm::hasSingleElement(payloadOps)) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
      diag.attachNote(transformValue.getLoc())
          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
      return diag;
    }

    Operation *op = *payloadOps.begin();
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}

// Given a packed handle that is either a param holding index attrs or a handle
// mapped to payload ops, append the unpacked values to `result`: the integer
// attributes themselves for a param, or the single index result of each mapped
// payload op for an op handle.
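// For example (illustrative only): a `!transform.param<i64>` associated with
// the attributes [2, 4] unpacks to those two attributes, while a handle mapped
// to a single `arith.constant 8 : index` unpacks to that op's result.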
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, Value packedHandle) {
  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
    ArrayRef<Attribute> params = state.getParams(packedHandle);
    for (auto param : params) {
      if (!isa<IntegerAttr>(param))
        return transformOp.emitDefiniteFailure()
               << "expected the parameter to be associated with an integer "
                  "attribute";
      result.push_back(param);
    }
    return DiagnosedSilenceableFailure::success();
  }

  for (Operation *op : state.getPayloadOps(packedHandle)) {
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}

/// When possible, converts each `OpFoldResult` in `mixedResults` to an integer
/// if the value can be statically inferred. If a result is a `Value`, then it
/// must be either a `ParamType` or a handle to a constant-like op.
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
    TransformState &state, TransformOpInterface &transformOp,
    ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
  for (OpFoldResult paramOrHandle : mixedResults) {
    if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
      reified.push_back(cast<IntegerAttr>(attr).getInt());
      continue;
    } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
      ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
      if (params.size() != 1)
        return transformOp.emitSilenceableError() << "expected a single param";
      reified.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }

    Value handle = cast<Value>(paramOrHandle);
    if (!isa<TransformHandleTypeInterface>(handle.getType()))
      return transformOp.emitSilenceableError() << "unexpected value handle";
    auto payload = state.getPayloadOps(handle);
    if (!llvm::hasSingleElement(payload))
      return transformOp.emitSilenceableError()
             << "requires param or handle that is mapped to 1 payload op";

    Operation *paramOrHandlePayloadOp = *payload.begin();
    if (paramOrHandlePayloadOp->getNumResults() != 1 ||
        !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
      return transformOp.emitSilenceableError()
             << "requires param or handle to be result of op with 1 index "
                "result";
    }

    IntegerAttr attr;
    if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
      return transformOp.emitSilenceableError()
             << "requires param or handle to be the result of a constant like "
                "op";

    reified.push_back(attr.getInt());
  }
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
229 } 230 231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns( 232 RewritePatternSet &patterns) { 233 linalg::populateDecomposePackUnpackPatterns(patterns); 234 } 235 236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns( 237 RewritePatternSet &patterns) { 238 linalg::populateDecomposePadPatterns(patterns); 239 } 240 241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( 242 RewritePatternSet &patterns) { 243 linalg::ControlDropUnitDims options; 244 linalg::populateFoldUnitExtentDimsPatterns(patterns, options); 245 } 246 247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns( 248 RewritePatternSet &patterns) { 249 linalg::ControlDropUnitDims options; 250 options.rankReductionStrategy = 251 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice; 252 linalg::populateFoldUnitExtentDimsPatterns(patterns, options); 253 } 254 255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( 256 RewritePatternSet &patterns) { 257 linalg::populateLinalgTilingCanonicalizationPatterns(patterns); 258 } 259 260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns( 261 RewritePatternSet &patterns) { 262 linalg::populateFoldAddIntoDestPatterns(patterns); 263 } 264 265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns( 266 RewritePatternSet &patterns) { 267 linalg::populatePadOpVectorizationPatterns(patterns); 268 linalg::populateInsertSliceVectorizationPatterns(patterns); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // BufferizeToAllocationOp 273 //===----------------------------------------------------------------------===// 274 275 void transform::BufferizeToAllocationOp::build(OpBuilder &b, 276 OperationState &result, 277 Value target, 278 Attribute memorySpace) { 279 SmallVector<Type> resultTypes; 280 resultTypes.push_back(b.getType<transform::AnyValueType>()); 281 resultTypes.push_back(b.getType<transform::AnyOpType>()); 282 return build(b, result, 283 /*resultTypes=*/resultTypes, 284 /*target=*/target, 285 /*memorySpace=*/memorySpace); 286 } 287 288 void transform::BufferizeToAllocationOp::build(OpBuilder &b, 289 OperationState &result, 290 Value target, 291 int64_t memorySpace) { 292 SmallVector<Type> resultTypes; 293 resultTypes.push_back(b.getType<transform::AnyValueType>()); 294 resultTypes.push_back(b.getType<transform::AnyOpType>()); 295 return build(b, result, 296 /*resultTypes=*/resultTypes, 297 /*target=*/target, 298 /*memorySpace=*/b.getI64IntegerAttr(memorySpace)); 299 } 300 301 namespace { 302 class NewOpsListener : public RewriterBase::ForwardingListener { 303 public: 304 using RewriterBase::ForwardingListener::ForwardingListener; 305 306 SmallVector<Operation *> getNewOps() const { 307 return SmallVector<Operation *>(newOps.begin(), newOps.end()); 308 } 309 310 private: 311 void notifyOperationInserted(Operation *op, 312 OpBuilder::InsertPoint previous) override { 313 ForwardingListener::notifyOperationInserted(op, previous); 314 // We only care about newly created ops. 
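    // A set `previous` insertion point means the op was moved from an existing
    // location rather than freshly created, so it is not tracked as a new op.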
315 if (previous.isSet()) 316 return; 317 auto inserted = newOps.insert(op); 318 (void)inserted; 319 assert(inserted.second && "expected newly created op"); 320 } 321 322 void notifyOperationErased(Operation *op) override { 323 ForwardingListener::notifyOperationErased(op); 324 op->walk([&](Operation *op) { newOps.erase(op); }); 325 } 326 327 DenseSet<Operation *> newOps; 328 }; 329 } // namespace 330 331 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( 332 transform::TransformRewriter &rewriter, 333 transform::TransformResults &results, transform::TransformState &state) { 334 // Attach listener to keep track of newly created ops. 335 OpBuilder::Listener *previousListener = rewriter.getListener(); 336 auto resetListener = 337 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); 338 NewOpsListener newOpsListener(previousListener); 339 rewriter.setListener(&newOpsListener); 340 341 linalg::BufferizeToAllocationOptions options; 342 if (getMemcpyOp() == "bufferization.materialize_in_destination") { 343 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp:: 344 MaterializeInDestination; 345 } else if (getMemcpyOp() == "memref.copy") { 346 options.memcpyOp = 347 linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy; 348 } else if (getMemcpyOp() == "linalg.copy") { 349 options.memcpyOp = 350 linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy; 351 } else { 352 llvm_unreachable("invalid memcpy op"); 353 } 354 if (getAllocOp() == "memref.alloc") { 355 options.allocOp = 356 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc; 357 } else if (getAllocOp() == "memref.alloca") { 358 options.allocOp = 359 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca; 360 } else { 361 llvm_unreachable("invalid alloc op"); 362 } 363 options.bufferizeDestinationOnly = getBufferizeDestinationOnly(); 364 options.emitDealloc = getEmitDealloc(); 365 366 // Bufferize ops. 367 Attribute memorySpace = 368 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute(); 369 SmallVector<Value> allocatedBuffers; 370 for (Operation *op : state.getPayloadOps(getTarget())) { 371 Value buffer = 372 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace); 373 if (!buffer) { 374 DiagnosedSilenceableFailure diag = emitSilenceableError() 375 << "failed to bufferize operation"; 376 diag.attachNote(op->getLoc()) << "target payload op"; 377 return diag; 378 } 379 allocatedBuffers.push_back(buffer); 380 } 381 382 // Set results. 383 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers); 384 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps()); 385 return DiagnosedSilenceableFailure::success(); 386 } 387 388 void transform::BufferizeToAllocationOp::getEffects( 389 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 390 if (getBufferizeDestinationOnly()) { 391 // The destination is replaced with a newly allocated buffer, but the op 392 // itself remains in place. 
    onlyReadsHandle(getTargetMutable(), effects);
  } else {
    consumesHandle(getTargetMutable(), effects);
  }
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

LogicalResult transform::BufferizeToAllocationOp::verify() {
  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
      getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
    return emitOpError() << "unsupported memcpy op";
  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
    return emitOpError() << "unsupported alloc op";
  return success();
}

//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
#define DOWNSCALE(trans)                                                       \
  {                                                                            \
    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
    if (succeeded(res)) {                                                      \
      results.push_back(*res);                                                 \
      return DiagnosedSilenceableFailure::success();                           \
    }                                                                          \
  }

#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
  DOWNSCALE(DownscaleConv2DOp)
#undef DOWNSCALE_NORMAL
#undef DOWNSCALE_CALL
#undef DOWNSCALE
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//

// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replace the values produced by
// \p target) into `results`.
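// For instance, aggregated ops such as `linalg.softmax` implement this
// interface and can be decomposed into a sequence of simpler ops.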
455 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne( 456 transform::TransformRewriter &rewriter, Operation *target, 457 transform::ApplyToEachResultList &results, 458 transform::TransformState &state) { 459 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target); 460 if (!decomposableOp) { 461 failed(rewriter.notifyMatchFailure(target, 462 "payload is not a decomposable op")); 463 return emitDefaultSilenceableFailure(target); 464 } 465 466 FailureOr<SmallVector<Value>> maybeNewResults = 467 decomposableOp.decomposeOperation(rewriter); 468 if (failed(maybeNewResults)) 469 return emitDefaultSilenceableFailure(target); 470 471 rewriter.replaceOp(decomposableOp, *maybeNewResults); 472 for (Value val : *maybeNewResults) { 473 Operation *definition = val.getDefiningOp(); 474 if (definition) 475 results.push_back(definition); 476 } 477 return DiagnosedSilenceableFailure::success(); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // EliminateLinalgOpAnchoredEmptyTensorsOp 482 //===----------------------------------------------------------------------===// 483 484 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects( 485 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 486 onlyReadsHandle(getTargetMutable(), effects); 487 modifiesPayload(effects); 488 } 489 490 DiagnosedSilenceableFailure 491 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( 492 transform::TransformRewriter &rewriter, TransformResults &transformResults, 493 TransformState &state) { 494 bufferization::OneShotBufferizationOptions options; 495 options.allowReturnAllocsFromLoops = true; 496 497 for (Operation *target : state.getPayloadOps(getTarget())) { 498 bufferization::OneShotAnalysisState state(target, options); 499 if (failed(analyzeOp(target, state))) 500 return mlir::emitSilenceableFailure(target->getLoc()) 501 << "failed to analyze op"; 502 if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep( 503 rewriter, target, state))) 504 return mlir::emitSilenceableFailure(target->getLoc()) 505 << "failed to eliminate LinalgOp anchored tensor.empty ops"; 506 } 507 return DiagnosedSilenceableFailure::success(); 508 } 509 510 //===----------------------------------------------------------------------===// 511 // FuseOp 512 //===----------------------------------------------------------------------===// 513 514 /// Apply a tiling transformation to all payload ops and store both the 515 /// tiled operation as well as the created tile loops. 516 template <typename Range> 517 static LogicalResult applyTilingToAll( 518 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, 519 unsigned numLoops, transform::TransformResults &transformResults, 520 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> 521 applyFn) { 522 SmallVector<Operation *> tiledLinalgOps; 523 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 524 525 for (Operation *target : payloadOps) { 526 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 527 if (!tilingInterfaceOp) 528 return transformOp->emitError("only TilingInterface ops are supported"); 529 530 rewriter.setInsertionPoint(target); 531 FailureOr<scf::SCFTileAndFuseResult> tiledResults = 532 applyFn(tilingInterfaceOp); 533 if (failed(tiledResults)) 534 return failure(); 535 536 // Perform the replacement of tiled and fused values. 
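      // Results of the original ops may now be produced by the tiled loop
      // nest; rewire those uses to the replacement values and erase ops that
      // become dead as a result.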
537 SmallVector<Operation *> opsToReplace{target}; 538 llvm::append_range(opsToReplace, tiledResults->fusedProducers); 539 for (Operation *toReplace : opsToReplace) { 540 for (OpResult res : toReplace->getResults()) 541 if (auto replacement = tiledResults->replacements.lookup(res)) 542 rewriter.replaceAllUsesWith(res, replacement); 543 if (toReplace->use_empty()) { 544 rewriter.eraseOp(toReplace); 545 } 546 } 547 548 // Report back the relevant handles to the transform op. 549 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); 550 assert(tiledResults->loops.size() == numLoops && 551 "Mismatched number of loops, tile and fuse transform should have " 552 "failed"); 553 for (unsigned int i = 0; i < numLoops; ++i) 554 loopOps[i].push_back(tiledResults->loops[i]); 555 } 556 557 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 558 for (unsigned int i = 0; i < numLoops; ++i) 559 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 560 561 return success(); 562 } 563 564 DiagnosedSilenceableFailure 565 transform::FuseOp::apply(transform::TransformRewriter &rewriter, 566 mlir::transform::TransformResults &transformResults, 567 mlir::transform::TransformState &state) { 568 SmallVector<int64_t> tileSizes = 569 extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 570 SmallVector<int64_t> tileInterchange = 571 extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); 572 573 scf::SCFTilingOptions tilingOptions; 574 tilingOptions.interchangeVector = tileInterchange; 575 SmallVector<OpFoldResult> tileSizesOfr = 576 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 577 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); 578 scf::SCFTileAndFuseOptions tileAndFuseOptions; 579 tileAndFuseOptions.tilingOptions = tilingOptions; 580 581 if (getApplyCleanup()) { 582 MLIRContext *context = rewriter.getContext(); 583 RewritePatternSet patterns(context); 584 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context); 585 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); 586 tileAndFuseOptions.cleanupPatterns = std::move(patterns); 587 } 588 589 LogicalResult result = applyTilingToAll( 590 rewriter, getOperation(), state.getPayloadOps(getTarget()), 591 tileSizes.size() - llvm::count(tileSizes, 0), transformResults, 592 [&](TilingInterface tilingInterfaceOp) 593 -> FailureOr<scf::SCFTileAndFuseResult> { 594 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, 595 tileAndFuseOptions); 596 }); 597 return failed(result) ? 
DiagnosedSilenceableFailure::definiteFailure() 598 : DiagnosedSilenceableFailure::success(); 599 } 600 601 LogicalResult transform::FuseOp::verify() { 602 SmallVector<int64_t> permutation = 603 extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); 604 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 605 if (!std::is_permutation(sequence.begin(), sequence.end(), 606 permutation.begin(), permutation.end())) { 607 return emitOpError() << "expects interchange to be a permutation, found " 608 << getTileInterchange(); 609 } 610 611 SmallVector<int64_t> sizes = 612 extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 613 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); 614 if (numExpectedLoops != getNumResults() - 1) 615 return emitOpError() << "expects " << numExpectedLoops << " loop results"; 616 617 return success(); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // FuseIntoContainingOp 622 //===----------------------------------------------------------------------===// 623 624 void transform::FuseIntoContainingOp::build(OpBuilder &builder, 625 OperationState &result, 626 Value producerOp, 627 Value containingOp) { 628 result.addOperands({producerOp, containingOp}); 629 auto resultType = transform::AnyOpType::get(builder.getContext()); 630 result.addTypes({resultType, resultType}); 631 } 632 633 /// Add new operands to the forall op for users of the producerOp 634 /// that are dominated by the containing scf.forall op. 635 static Operation *replaceForAllWithNewSignature( 636 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, 637 Operation *containingOp, TilingResult &tileAndFuseResult, 638 int64_t resultNumber, SmallVector<OpFoldResult> &offsets, 639 SmallVector<OpFoldResult> &sizes) { 640 641 // Count number of users not including the containing op 642 SetVector<Operation *> dominatedUsers; 643 DominanceInfo domInfo(containingOp); 644 for (Operation *user : producerOp->getResult(resultNumber).getUsers()) { 645 if (!containingOp->isAncestor(user) && 646 (domInfo.dominates(containingOp, user))) { 647 dominatedUsers.insert(user); 648 } 649 } 650 if (dominatedUsers.empty()) 651 return nullptr; 652 653 // Create new scf.forall op 654 auto forallOp = cast<scf::ForallOp>(containingOp); 655 OpBuilder::InsertionGuard g(rewriter); 656 rewriter.setInsertionPoint(forallOp); 657 658 // Get new output 659 Location loc = forallOp.getLoc(); 660 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp); 661 if (!genericOp) 662 return nullptr; 663 SmallVector<Value> outputs = genericOp.getOutputs(); 664 SmallVector<Value> newOuts(forallOp.getOutputs()); 665 newOuts.push_back(outputs[resultNumber]); 666 667 // Create new scf.forall op 668 auto newforallOp = rewriter.create<scf::ForallOp>( 669 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), 670 forallOp.getMixedStep(), newOuts, forallOp.getMapping()); 671 rewriter.eraseBlock(newforallOp.getBody()); 672 newforallOp.getRegion().takeBody(forallOp.getRegion()); 673 674 // Add additional block argument for new value being returned 675 // and replaces all uses of the new output with corresponding bbArg 676 // inside the scf.forall to enable fusion into this new scf.forall. 
677 newforallOp.getBody()->addArgument(newOuts.back().getType(), 678 newOuts.back().getLoc()); 679 auto bbArgs = newforallOp.getBody()->getArguments(); 680 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(), 681 [&](OpOperand &use) { 682 Operation *op = use.getOwner(); 683 return newforallOp->isProperAncestor(op); 684 }); 685 686 // Fix terminator 687 scf::InParallelOp terminatorOp = newforallOp.getTerminator(); 688 SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range( 689 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; })); 690 Operation *firstYieldOp = yieldingOps.front(); 691 rewriter.setInsertionPoint(firstYieldOp); 692 Value src = tileAndFuseResult.tiledValues[0]; 693 Value dst = newforallOp.getRegionIterArgs().back(); 694 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1)); 695 rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src, 696 dst, offsets, sizes, strides); 697 698 for (auto result : llvm::enumerate(forallOp.getResults())) { 699 rewriter.replaceAllUsesWith(result.value(), 700 newforallOp->getResult(result.index())); 701 } 702 rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber), 703 newforallOp->getResults().back(), 704 [&](OpOperand &use) { 705 Operation *user = use.getOwner(); 706 return dominatedUsers.contains(user); 707 }); 708 return newforallOp; 709 } 710 711 /// Find the first "extract" user of `producerOp` and tile it right before its 712 /// use. The tiled op is fused under the `containingOp`. 713 /// Return this fused op on success or nullptr if anything fails. 714 /// If tiled op has uses that are dominated by `containingOp`, return 715 /// a new `containingOp` with results of the fused op appended to 716 /// results of the `containingOp` or nullptr if there are no dominated uses. 717 static std::tuple<SmallVector<Operation *>, Operation *> 718 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, 719 Operation *producerOp, Operation *containingOp) { 720 LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); 721 auto tileableProducer = dyn_cast<TilingInterface>(producerOp); 722 if (!tileableProducer) { 723 diag.attachNote(producerOp->getLoc()) 724 << "producer is not a TileableInterface: " << *producerOp; 725 return {}; 726 } 727 728 // Search the producer slices accessed within the containing operation. 729 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe 730 // evolve into an interface. 731 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { 732 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); 733 return sliceOp && containingOp->isProperAncestor(sliceOp); 734 }); 735 736 // Find a fusion opportunity. 737 if (it == tileableProducer->getUsers().end()) { 738 diag.attachNote(tileableProducer->getLoc()) 739 << "could not find fusion opportunity for: " << *tileableProducer; 740 return {}; 741 } 742 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it); 743 744 // Try to fuse the producer in-place. 745 OpBuilder::InsertionGuard guard(rewriter); 746 rewriter.setInsertionPoint(sliceOpToTile); 747 748 // Tile the producer. 
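  // The extract_slice reads one of the producer's results; its result number
  // selects which result tile to generate.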
749 int64_t resultNumber = 750 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber(); 751 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); 752 753 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets(); 754 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes(); 755 756 FailureOr<TilingResult> tileAndFuseResult = 757 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets, 758 sizes); 759 760 if (failed(tileAndFuseResult)) { 761 diag.attachNote(tileableProducer->getLoc()) 762 << "failed to tile producer op: " << *tileableProducer; 763 return {}; 764 } 765 766 #ifndef NDEBUG 767 for (auto *tiledOp : tileAndFuseResult->tiledOps) { 768 LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); 769 } 770 #endif 771 772 // Replace the extract op. 773 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( 774 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], 775 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape()); 776 if (failed(maybeRankReduced)) { 777 diag.attachNote(producerOp->getLoc()) 778 << "shape types don't match (missing canonicalization?):\nTiledOp: " 779 << tileAndFuseResult->tiledValues[0] 780 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n'; 781 return {}; 782 } 783 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); 784 785 // Add new outputs to containing op, if required 786 Operation *newContainingOp = replaceForAllWithNewSignature( 787 rewriter, diag, producerOp, containingOp, *tileAndFuseResult, 788 resultNumber, offsets, sizes); 789 790 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp); 791 } 792 793 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure 794 /// it is exactly the `containingOp`, otherwise bail. 795 /// Then, find the first "extract" user of the tied block argument and tile it 796 /// right before its "extract" use. The tiled op is fused under the 797 /// `containingOp`. 798 /// Return this fused op on success or nullptr if anything fails. 799 static SmallVector<Operation *> 800 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( 801 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, 802 Operation *containingOp) { 803 LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); 804 805 auto tileableProducer = dyn_cast<TilingInterface>(producerOp); 806 if (!tileableProducer) { 807 diag.attachNote(producerOp->getLoc()) 808 << "producer is not a TileableInterface: " << *producerOp; 809 return {}; 810 } 811 812 // Search the first use by a "scf::ForallOp" user. 813 scf::ForallOp forallOp; 814 auto itProducerUses = 815 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { 816 forallOp = dyn_cast<scf::ForallOp>(use.getOwner()); 817 return forallOp; 818 }); 819 // If it's not from the containing op, return. 820 if (!forallOp || forallOp != containingOp) { 821 diag.attachNote(tileableProducer->getLoc()) 822 << "could not find a use by the containing op: " << *tileableProducer; 823 return {}; 824 } 825 826 // Search the producer slices accessed within the containing 827 // operation. 828 // TODO: Generalize to more extract/insert/parallel_insert triples. 829 // Maybe evolve into an interface. 830 OpOperand *pUse = &(*itProducerUses); 831 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse); 832 833 // Search the producer slices accessed within the containing operation. 
834 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe 835 // evolve into an interface. 836 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { 837 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); 838 return sliceOp && containingOp->isProperAncestor(sliceOp); 839 }); 840 841 // Find a fusion opportunity. 842 if (itBBArgUsers == bbArg.getUsers().end()) { 843 diag.attachNote(containingOp->getLoc()) 844 << "could not find fusion opportunity for bbArg: " << bbArg; 845 return {}; 846 } 847 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers); 848 849 // Try to fuse the producer in-place. 850 OpBuilder::InsertionGuard guard(rewriter); 851 rewriter.setInsertionPoint(sliceOpToTile); 852 853 // Replace the use in the tileableProducer before tiling: clone, replace and 854 // then tile. 855 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber(); 856 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); 857 858 // Gather destination tensors. 859 SmallVector<Value> destinationTensors; 860 if (failed(tensor::getOrCreateDestinations( 861 rewriter, tileableProducer->getLoc(), tileableProducer, 862 destinationTensors))) { 863 diag.attachNote(tileableProducer->getLoc()) 864 << "failed to get destination tensors for: " << *tileableProducer; 865 return {}; 866 } 867 868 IRMapping bvm; 869 bvm.map(destinationTensors[resultNumber], bbArg); 870 auto tileableProducerClone = 871 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm)); 872 auto scopeGuard = 873 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); 874 875 // Tile the producer. 876 FailureOr<TilingResult> tileAndFuseResult = 877 tileableProducerClone.generateResultTileValue( 878 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), 879 sliceOpToTile.getMixedSizes()); 880 if (failed(tileAndFuseResult)) { 881 diag.attachNote(tileableProducer->getLoc()) 882 << "failed to tile producer op: " << *tileableProducer; 883 return {}; 884 } 885 886 // Replace the extract op. 887 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( 888 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], 889 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape()); 890 assert(succeeded(maybeRankReduced) && "unexpected shape"); 891 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); 892 893 // Replace the use in containingOp. 894 rewriter.modifyOpInPlace(containingOp, [&]() { 895 containingOp->setOperand(pUse->getOperandNumber(), 896 destinationTensors.front()); 897 }); 898 899 return tileAndFuseResult->tiledOps; 900 } 901 902 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, 903 Operation *producerOp, 904 Operation *containingOp) { 905 LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n"); 906 907 // Gather all uses inside the containing op. 908 SmallVector<OpOperand *> uses; 909 for (OpResult result : producerOp->getOpResults()) { 910 for (OpOperand &use : result.getUses()) { 911 if (containingOp->isProperAncestor(use.getOwner())) { 912 uses.push_back(&use); 913 continue; 914 } 915 // Cannot clone and fuse if the use is by the containing op itself: fail 916 // immediately. 917 if (containingOp == use.getOwner()) { 918 diag.attachNote(producerOp->getLoc()) 919 << "producer op use by containing op cannot be fused by cloning"; 920 return nullptr; 921 } 922 } 923 } 924 925 // Check for a non-empty list of fusion opportunities. 
926 if (uses.empty()) { 927 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning"; 928 return nullptr; 929 } 930 931 // Clone and fuse inside the containing op. 932 Operation *fusedOp = nullptr; 933 OpOperand *use = uses.front(); 934 // Parallel insert slice is not a valid clone destination. 935 // TODO: Generalize to other type of ops. 936 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && 937 "Parallel insert slice is not a valid clone destination"); 938 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber(); 939 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); 940 941 OpBuilder::InsertionGuard guard(rewriter); 942 rewriter.setInsertionPoint(use->getOwner()); 943 fusedOp = rewriter.clone(*producerOp); 944 rewriter.modifyOpInPlace( 945 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); 946 947 return fusedOp; 948 } 949 950 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { 951 // Allow repeated handles since we are fusing everything anyway. 952 return true; 953 } 954 955 DiagnosedSilenceableFailure 956 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, 957 transform::TransformResults &results, 958 transform::TransformState &state) { 959 SmallVector<Operation *> fusedOps; 960 auto producerOps = state.getPayloadOps(getProducerOp()); 961 auto containingOps = state.getPayloadOps(getContainingOp()); 962 if (!llvm::hasSingleElement(containingOps)) { 963 return emitDefiniteFailure() 964 << "requires exactly one containing_op handle (got " 965 << llvm::range_size(containingOps) << ")"; 966 } 967 Operation *containingOp = *containingOps.begin(); 968 969 // If nothing to fuse, propagate success. 970 if (std::empty(producerOps)) { 971 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{}); 972 results.set(cast<OpResult>(getNewContainingOp()), {containingOp}); 973 return DiagnosedSilenceableFailure::success(); 974 } 975 976 // Helper function to find the next producer that should be fused. Take any 977 // producer that has a use inside the containing op. 978 SetVector<Operation *> remainingProducers(producerOps.begin(), 979 producerOps.end()); 980 auto getNextProducer = [&]() -> FailureOr<Operation *> { 981 for (const auto &it : enumerate(remainingProducers)) { 982 Operation *producerOp = it.value(); 983 // The containing op may be a user of producerOp: use isAncestor. 984 int64_t numUsesInContainingOp = 985 llvm::count_if(producerOp->getUsers(), [&](Operation *op) { 986 return containingOp->isAncestor(op); 987 }); 988 // TODO: When resolving the TODO below (no duplicate ops), take an op 989 // that has no use among the remaining producers. This is a topological 990 // sorting. 991 if (numUsesInContainingOp > 0) { 992 if (numUsesInContainingOp == 1) 993 remainingProducers.erase(remainingProducers.begin() + it.index()); 994 return producerOp; 995 } 996 } 997 return failure(); 998 }; 999 1000 while (!remainingProducers.empty()) { 1001 auto nextProducer = getNextProducer(); 1002 if (failed(nextProducer)) { 1003 auto diag = mlir::emitSilenceableFailure(getLoc()) 1004 << "could not find next producer to fuse into container"; 1005 diag.attachNote(containingOp->getLoc()) << "containing op"; 1006 return diag; 1007 } 1008 1009 Operation *producerOp = *nextProducer; 1010 1011 // Default diagnostic, to be complemented with more failure information. 
    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    // TODO: If there are multiple uses of the producer in the containing op,
    // we currently tile/clone the op multiple times (once per use). In some
    // cases, we can tile/clone once and reuse the value for each use.
    // Furthermore, producers should then be traversed according to a
    // topological sorting.
    auto [tiledOps, newContainingOp] =
        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
    if (!tiledOps.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
      fusedOps.append(tiledOps);
      if (newContainingOp) {
        // Update handles associated with the containing op so we don't need to
        // invalidate them. This is a hack to support better composability
        // between tiling and fusion while a proper mechanism is being
        // investigated.
        //
        // DO NOT replicate this elsewhere unless you understand what you are
        // doing.
        LogicalResult replacementStatus =
            rewriter.notifyPayloadOperationReplaced(containingOp,
                                                    newContainingOp);
        (void)replacementStatus;
        assert(succeeded(replacementStatus) &&
               "unable to update transform state mapping");
        rewriter.eraseOp(containingOp);
        containingOp = newContainingOp;
      }
      continue;
    }

    SmallVector<Operation *> tiledContainingOpOperand =
        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
            rewriter, diag, producerOp, containingOp);
    if (!tiledContainingOpOperand.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
                        << *containingOp);
      fusedOps.append(tiledContainingOpOperand);
      continue;
    }

    Operation *cloned =
        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
    if (cloned) {
      LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
      fusedOps.push_back(cloned);
      continue;
    }
    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
  }

  results.set(cast<OpResult>(getFusedOp()), fusedOps);
  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
  return DiagnosedSilenceableFailure::success();
}

void transform::FuseIntoContainingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getProducerOpMutable(), effects);
  onlyReadsHandle(getContainingOpMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if no transformation is needed.
  if (isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
  if (succeeded(generic)) {
    results.push_back(generic->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if the operation is not a generic.
  if (!isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> named =
      specializeGenericOp(rewriter, cast<GenericOp>(target));
  if (succeeded(named)) {
    results.push_back(named->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
                                     GenericOp target,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
  // Exit early if no transformation is needed.
1136 if (interchangeVector.empty()) { 1137 results.push_back(target); 1138 return DiagnosedSilenceableFailure::success(); 1139 } 1140 1141 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops(); 1142 if (interchangeVector.size() != numLoops) { 1143 return emitSilenceableError() 1144 << getIteratorInterchangeAttrName() << " has length (" 1145 << interchangeVector.size() 1146 << ") different from the number of loops in the target operation (" 1147 << numLoops << ")"; 1148 } 1149 FailureOr<GenericOp> res = interchangeGenericOp( 1150 rewriter, target, SmallVector<unsigned>(interchangeVector)); 1151 if (failed(res)) 1152 return emitDefiniteFailure() << "failed to apply"; 1153 results.push_back(res->getOperation()); 1154 return DiagnosedSilenceableFailure::success(); 1155 } 1156 1157 LogicalResult transform::InterchangeOp::verify() { 1158 ArrayRef<int64_t> permutation = getIteratorInterchange(); 1159 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); 1160 if (!std::is_permutation(sequence.begin(), sequence.end(), 1161 permutation.begin(), permutation.end())) { 1162 return emitOpError() 1163 << "expects iterator_interchange to be a permutation, found " 1164 << getIteratorInterchange(); 1165 } 1166 return success(); 1167 } 1168 1169 //===----------------------------------------------------------------------===// 1170 // LowerPackOp 1171 //===----------------------------------------------------------------------===// 1172 1173 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( 1174 transform::TransformRewriter &rewriter, tensor::PackOp target, 1175 transform::ApplyToEachResultList &transformResults, 1176 transform::TransformState &state) { 1177 rewriter.setInsertionPoint(target); 1178 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice(); 1179 FailureOr<LowerPackResult> res = 1180 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice); 1181 if (failed(res)) { 1182 return mlir::emitSilenceableFailure(target->getLoc()) 1183 << "cannot lower to pad + expand + transpose"; 1184 } 1185 transformResults.push_back(res->padOp); 1186 transformResults.push_back(res->expandShapeOp); 1187 transformResults.push_back(res->transposeOp); 1188 return DiagnosedSilenceableFailure::success(); 1189 } 1190 1191 //===----------------------------------------------------------------------===// 1192 // LowerUnPackOp 1193 //===----------------------------------------------------------------------===// 1194 1195 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( 1196 transform::TransformRewriter &rewriter, tensor::UnPackOp target, 1197 transform::ApplyToEachResultList &transformResults, 1198 transform::TransformState &state) { 1199 rewriter.setInsertionPoint(target); 1200 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice(); 1201 FailureOr<LowerUnPackOpResult> res = 1202 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice); 1203 if (failed(res)) { 1204 DiagnosedSilenceableFailure diag = 1205 emitSilenceableError() 1206 << "cannot lower to transpose + collapse + extract"; 1207 diag.attachNote(target->getLoc()) << "target payload op"; 1208 return diag; 1209 } 1210 transformResults.push_back(res->emptyOp); 1211 transformResults.push_back(res->transposeOp); 1212 transformResults.push_back(res->collapseShapeOp); 1213 transformResults.push_back(res->extractSliceOp); 1214 return DiagnosedSilenceableFailure::success(); 1215 } 1216 1217 //===---------------------------------------------------------------------===// 1218 
// MatchOp 1219 //===---------------------------------------------------------------------===// 1220 1221 void transform::MatchOp::build(OpBuilder &builder, OperationState &result, 1222 Value target, ArrayRef<StringRef> opNames) { 1223 result.addOperands(target); 1224 result.addAttribute(MatchOp::getOpsAttrName(result.name), 1225 builder.getStrArrayAttr(opNames)); 1226 result.addTypes(transform::AnyOpType::get(builder.getContext())); 1227 } 1228 1229 void transform::MatchOp::build(OpBuilder &builder, OperationState &result, 1230 TypeRange resultTypes, Value target, 1231 ArrayRef<StringRef> opNames) { 1232 result.addOperands(target); 1233 result.addAttribute(MatchOp::getOpsAttrName(result.name), 1234 builder.getStrArrayAttr(opNames)); 1235 result.addTypes(resultTypes); 1236 } 1237 1238 DiagnosedSilenceableFailure 1239 transform::MatchOp::apply(transform::TransformRewriter &rewriter, 1240 transform::TransformResults &results, 1241 transform::TransformState &state) { 1242 llvm::StringSet<> strs; 1243 if (getOps().has_value()) 1244 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(), 1245 getOps()->getAsValueRange<StringAttr>().end()); 1246 1247 auto payloadOps = state.getPayloadOps(getTarget()); 1248 if (!llvm::hasSingleElement(payloadOps)) { 1249 return emitDefiniteFailure("requires exactly one target handle"); 1250 } 1251 1252 SmallVector<Operation *> res; 1253 bool incorrectNumOperandTypes = false; 1254 auto matchFun = [&](Operation *op) { 1255 if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) 1256 return; 1257 1258 // Interfaces cannot be matched by name, just by ID. 1259 // So we specifically encode the interfaces we care about for this op. 1260 if (getInterface().has_value()) { 1261 auto iface = getInterface().value(); 1262 if (iface == transform::MatchInterfaceEnum::LinalgOp && 1263 !isa<LinalgOp>(op)) 1264 return; 1265 if (iface == transform::MatchInterfaceEnum::TilingInterface && 1266 !isa<TilingInterface>(op)) 1267 return; 1268 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface && 1269 !isa<LoopLikeOpInterface>(op)) 1270 return; 1271 } 1272 1273 // Check if all specified attributes match. 
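    // Entries that merely configure this match op (the `ops` and `interface`
    // attributes) are skipped and need not be present on the payload op.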
    if (getOpAttrs().has_value()) {
      DictionaryAttr opAttrs = getOpAttrs().value();
      for (NamedAttribute attr : opAttrs) {
        if (attr.getName() == getInterfaceAttrName() ||
            attr.getName() == getOpsAttrName())
          continue;
        if (!op->hasAttr(attr.getName()))
          return;
        if (op->getAttr(attr.getName()) != attr.getValue())
          return;
      }
    }

    if (getFilterResultType().has_value()) {
      Type t = getFilterResultType().value();
      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
        return;
    }

    if (getFilterOperandTypes().has_value()) {
      mlir::ArrayAttr types = getFilterOperandTypes().value();
      auto operandTypes = op->getOperandTypes();

      if (types.size() == 1) {
        // All the operands must be equal to the specified type.
        auto typeattr =
            dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
        Type t = cast<::mlir::Type>(typeattr.getValue());
        if (!llvm::all_of(op->getOperandTypes(),
                          [&](Type operandType) { return operandType == t; }))
          return;
      } else {
        // The operand types must match all the types in the list (in the same
        // order in which they are specified).
        if (types.size() != operandTypes.size()) {
          incorrectNumOperandTypes = true;
          return;
        }

        for (auto [attr, operandType] :
             llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
          auto typeattr = cast<mlir::TypeAttr>(attr);
          Type type = cast<::mlir::Type>(typeattr.getValue());

          if (type != operandType)
            return;
        }
      }
    }

    // All constraints are satisfied.
    res.push_back(op);
    return;
  };

  (*payloadOps.begin())->walk(matchFun);
  if (incorrectNumOperandTypes)
    return emitDefiniteFailure("If filter_operand_types contains more than "
                               "one type, then it must contain as many types "
                               "as the number of operands in the target ops");
  results.set(cast<OpResult>(getResult()), res);
  return DiagnosedSilenceableFailure::success();
}

//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//

static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
                                     Type targetType, Type lowSizeType, Type,
                                     Type) {
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
}

static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
                                            Type &targetType, Type &lowSizeType,
                                            Type &highSizeType,
                                            Type &splitPointType) {
  FunctionType funcType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseType<FunctionType>(funcType)))
    return failure();

  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
  }
  targetType = funcType.getInput(0);
  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

  return success();
}

DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results, TransformState &state) {
  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
    if (target.hasDynamicShape()) {
      auto diag =
emitSilenceableError() 1373 << "cannot compute parametric tile sizes for dynamically " 1374 "shaped payload op"; 1375 diag.attachNote(target->getLoc()) << "payload op"; 1376 return diag; 1377 } 1378 1379 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes( 1380 target, getDimension(), getTargetSize(), getDivisor()); 1381 if (failed(spec)) { 1382 return emitSilenceableError() 1383 << "failed to compute multi-size tiling sizes"; 1384 } 1385 1386 Builder builder(target.getContext()); 1387 results.assign(llvm::map_range( 1388 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize, 1389 spec->lowTileSize * spec->lowTripCount}), 1390 [&builder, this](int64_t value) { 1391 return builder.getIntegerAttr( 1392 cast<ParamType>(getLowSize().getType()).getType(), value); 1393 })); 1394 return DiagnosedSilenceableFailure::success(); 1395 } 1396 1397 OpBuilder builder(target.getContext()); 1398 builder.setInsertionPoint(target); 1399 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); 1400 OpFoldResult divisor = builder.getIndexAttr(getDivisor()); 1401 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( 1402 builder, target, getDimension(), targetSize, divisor); 1403 if (failed(spec)) { 1404 return emitSilenceableError() << "could not generate tile size computation"; 1405 } 1406 1407 AffineExpr s0 = builder.getAffineSymbolExpr(0); 1408 AffineExpr s1 = builder.getAffineSymbolExpr(1); 1409 Operation *splitPoint = 1410 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1, 1411 {spec->lowTileSize, spec->lowTripCount}); 1412 Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); 1413 Operation *highTileSize = spec->highTileSize.getDefiningOp(); 1414 assert(lowTileSize && highTileSize && splitPoint && 1415 "tile sizes are not produced by operations"); 1416 results.reserve(results.size() + 3); 1417 results.push_back(lowTileSize); 1418 results.push_back(highTileSize); 1419 results.push_back(splitPoint); 1420 return DiagnosedSilenceableFailure::success(); 1421 } 1422 1423 void transform::MultiTileSizesOp::getEffects( 1424 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1425 onlyReadsHandle(getTargetMutable(), effects); 1426 producesHandle(getOperation()->getOpResults(), effects); 1427 if (isa<TransformParamTypeInterface>(getLowSize().getType())) 1428 onlyReadsPayload(effects); 1429 else 1430 modifiesPayload(effects); 1431 } 1432 1433 LogicalResult transform::MultiTileSizesOp::verify() { 1434 if (getLowSize().getType() != getHighSize().getType() || 1435 getLowSize().getType() != getSplitPoint().getType()) { 1436 return emitOpError() << "expects all results type to be the same"; 1437 } 1438 return success(); 1439 } 1440 1441 //===---------------------------------------------------------------------===// 1442 // PackOp 1443 //===---------------------------------------------------------------------===// 1444 1445 void transform::PackOp::build(OpBuilder &builder, OperationState &result, 1446 Value target, 1447 ArrayRef<OpFoldResult> mixedPackedSizes) { 1448 SmallVector<int64_t> staticPackedSizes; 1449 SmallVector<Value> dynamicPackedSizes; 1450 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes, 1451 staticPackedSizes); 1452 // Call the default builder which sets up the proper operands segment sizes 1453 // attributes for multiple variadic operands. In the absence of this, horrible 1454 // bugs ensue. 
1455   Type linalgOpHType = transform::OperationType::get(
1456       builder.getContext(), GenericOp::getOperationName());
1457   build(builder, result,
1458         /*resultType=*/linalgOpHType,
1459         /*target=*/target,
1460         /*dynamic_sizes=*/dynamicPackedSizes,
1461         /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1462 }
1463
1464 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1465   Builder b(getContext());
1466   return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1467 }
1468
1469 DiagnosedSilenceableFailure
1470 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1471                          transform::TransformResults &transformResults,
1472                          transform::TransformState &state) {
1473   auto targetOps = state.getPayloadOps(getTarget());
1474   // If nothing to pack, propagate success.
1475   if (std::empty(targetOps)) {
1476     transformResults.set(cast<OpResult>(getPackedOp()),
1477                          ArrayRef<Operation *>({}));
1478     return DiagnosedSilenceableFailure::success();
1479   }
1480   // Fail on multi-op handles.
1481   auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1482   if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1483     return emitSilenceableError()
1484            << "requires target to map to exactly 1 LinalgOp (got "
1485            << llvm::range_size(targetOps) << ")";
1486   }
1487   // Fail on mismatched number of pack sizes.
1488   if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1489     return emitSilenceableError()
1490            << "requires number of packed sizes to match the number of loops ("
1491            << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1492            << ")";
1493   }
1494
1495   // Unpack handles to constants or actual SSA index values.
1496   SmallVector<OpFoldResult> packedSizes;
1497   DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1498       state, *this, packedSizes, getMixedPackedSizes());
   if (!status.succeeded())
     return status;
1499
1500   rewriter.setInsertionPoint(linalgOp);
1501   FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1502   if (failed(maybeResult))
1503     return emitDefiniteFailure("data tiling failed");
1504
1505   transformResults.set(cast<OpResult>(getPackedOp()),
1506                        {maybeResult->packedLinalgOp.getOperation()});
1507   return DiagnosedSilenceableFailure::success();
1508 }
1509
1510 void transform::PackOp::getEffects(
1511     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1512   transform::consumesHandle(getTargetMutable(), effects);
1513   transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1514   transform::producesHandle(getOperation()->getOpResults(), effects);
1515   transform::modifiesPayload(effects);
1516 }
1517
1518 //===---------------------------------------------------------------------===//
1519 // PackGreedilyOp.
1520 //===---------------------------------------------------------------------===//
1521
1522 LogicalResult transform::PackGreedilyOp::verify() {
1523   if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1524     return emitOpError() << getMatmulInnerDimsOrderAttrName()
1525                          << " is not a valid permutation";
1526   }
1527   // TODO: relax to allow empty once we have another strategy than just matmul.
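  // Note: for each dimension of the matmul strategy, a nonzero
  // padded_sizes_next_multiple_of entry is only accepted when the matching
  // packed size is 0 (i.e. unspecified); the two ways of choosing a size are
  // mutually exclusive per dimension.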
1528 if (!getMatmulPaddedSizesNextMultipleOf().empty()) { 1529 for (auto [s, nmo] : 1530 llvm::zip_equal(getMixedMatmulPackedSizes(), 1531 getMatmulPaddedSizesNextMultipleOf())) { 1532 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s); 1533 if (nmo != 0 && 1534 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) { 1535 return emitOpError() << "at most one of the packed_size and the " 1536 "padded_sizes_next_multiple_of can be nonzero " 1537 "for the matmul strategy"; 1538 } 1539 } 1540 } 1541 return success(); 1542 } 1543 1544 DiagnosedSilenceableFailure 1545 PackGreedilyOp::apply(transform::TransformRewriter &rewriter, 1546 transform::TransformResults &transformResults, 1547 transform::TransformState &state) { 1548 SmallVector<Operation *> results; 1549 for (Operation *op : state.getPayloadOps(getTarget())) { 1550 auto linalgOp = dyn_cast<LinalgOp>(op); 1551 if (!linalgOp) 1552 continue; 1553 // linalgOp will be replaced and the insertion point may be invalidated if 1554 // we set it before -> set it after. 1555 rewriter.setInsertionPointAfter(linalgOp); 1556 // Failing to pack greedily is perfectly fine. 1557 // In the future we will want to order packings according to some metric. 1558 FailureOr<PackResult> packResult = packMatmulGreedily( 1559 /*rewriter=*/rewriter, 1560 /*linalgOp=*/linalgOp, 1561 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(), 1562 /*mnkPaddedSizesNextMultipleOf=*/ 1563 getMatmulPaddedSizesNextMultipleOf(), 1564 /*mnkOrder=*/getMatmulInnerDimsOrder()); 1565 if (succeeded(packResult)) { 1566 results.push_back(packResult->packedLinalgOp); 1567 continue; 1568 } 1569 results.push_back(linalgOp); 1570 } 1571 transformResults.set(cast<OpResult>(getPackedOp()), results); 1572 return DiagnosedSilenceableFailure::success(); 1573 } 1574 1575 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() { 1576 Builder b(getContext()); 1577 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(), 1578 b); 1579 } 1580 1581 void transform::PackGreedilyOp::getEffects( 1582 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1583 transform::consumesHandle(getTargetMutable(), effects); 1584 transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects); 1585 transform::producesHandle(getOperation()->getOpResults(), effects); 1586 transform::modifiesPayload(effects); 1587 } 1588 1589 //===---------------------------------------------------------------------===// 1590 // PackTransposeOp 1591 //===---------------------------------------------------------------------===// 1592 1593 LogicalResult transform::PackTransposeOp::verify() { 1594 if (!isPermutationVector(getInnerPerm())) { 1595 return emitOpError() << getInnerPermAttrName() 1596 << " is not a valid permutation"; 1597 } 1598 if (!isPermutationVector(getOuterPerm())) { 1599 return emitOpError() << getOuterPermAttrName() 1600 << " is not a valid permutation"; 1601 } 1602 if (getInnerPerm().empty() && getOuterPerm().empty()) { 1603 return emitOpError() << " at least one of " << getInnerPermAttrName() 1604 << " or " << getOuterPermAttrName() 1605 << " must be specified"; 1606 } 1607 return success(); 1608 } 1609 1610 namespace { 1611 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 }; 1612 } // namespace 1613 1614 /// Return true if `permutation` is a valid permutation of the 1615 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` 1616 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op. 
1617 /// This is the case when the `permutation` rank matches the rank expected by 1618 /// `op` and `permutation` is itself a permutation vector. 1619 /// Return true if either `op` or `permutation` are empty to allow a simpler 1620 /// polymorphic implementation. 1621 template <typename RelayoutOpTy> 1622 bool isValidPackingPermutation( 1623 RelayoutOpTy op, ArrayRef<int64_t> permutation, 1624 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) { 1625 static_assert( 1626 llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value, 1627 "applies to only pack or unpack operations"); 1628 if (!op || permutation.empty()) 1629 return true; 1630 size_t innerRank = op.getInnerDimsPos().size(); 1631 if (outerOrInnerPerm == OuterOrInnerPerm::Inner) 1632 return permutation.size() == innerRank && isPermutationVector(permutation); 1633 // op.getOuterDimsPerm() may be empty, in which case it is identity. 1634 // Don't rely on it. 1635 if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) { 1636 return permutation.size() == op.getSourceRank() && 1637 isPermutationVector(permutation); 1638 } 1639 return permutation.size() == op.getDestRank() && 1640 isPermutationVector(permutation); 1641 } 1642 1643 DiagnosedSilenceableFailure 1644 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, 1645 transform::TransformResults &transformResults, 1646 transform::TransformState &state) { 1647 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); 1648 auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); 1649 // Step 1. If nothing to pack, propagate success. 1650 if (std::empty(packOrUnpackOps)) { 1651 transformResults.set(cast<OpResult>(getPackedOp()), {}); 1652 transformResults.set(cast<OpResult>(getPackOp()), {}); 1653 transformResults.set(cast<OpResult>(getUnPackOp()), {}); 1654 return DiagnosedSilenceableFailure::success(); 1655 } 1656 1657 // Step 2. Bunch of runtime sanity check and error messages. 1658 // Step 2.1. Fail on multi-op handles. 1659 if (!llvm::hasSingleElement(packOrUnpackOps) || 1660 !llvm::hasSingleElement(linalgOps)) { 1661 return emitSilenceableError() 1662 << "requires target to map to exactly 1 " 1663 "packing op and 1 packed op (" 1664 << "got " << llvm::range_size(packOrUnpackOps) << " and " 1665 << llvm::range_size(linalgOps) << ")"; 1666 } 1667 1668 // Step 2.2. Fail on wrong type. 1669 auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin()); 1670 auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin()); 1671 if ((!packOp && !unPackOp)) { 1672 return emitSilenceableError() << "requires target to map to a " 1673 "tensor.pack or tensor.unpack"; 1674 } 1675 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin()); 1676 if (!linalgOpTarget) 1677 return emitSilenceableError() << "requires a LinalgOp target"; 1678 1679 // Step 2.3. Fail if we can't get the producer / consumer Linalg op. 1680 LinalgOp linalgOp; 1681 if (packOp && packOp.getResult().hasOneUse()) 1682 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin())); 1683 else if (unPackOp) 1684 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>(); 1685 if (linalgOp != linalgOpTarget) { 1686 auto errorMsg = 1687 packOp ? StringLiteral{"not a single use by the LinalgOp target"} 1688 : StringLiteral{"not produced by the LinalgOp target"}; 1689 return emitSilenceableError() << errorMsg; 1690 } 1691 1692 // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical 1693 // PackOp. 
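  // The unpack consumes one of the DPS inits of the target LinalgOp; the pack
  // that produced that init operand is the symmetrical op that must be
  // transposed consistently with the unpack.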
1694 if (unPackOp) { 1695 assert(!packOp && "packOp must be null on entry when unPackOp is not null"); 1696 OpOperand *packUse = linalgOp.getDpsInitOperand( 1697 cast<OpResult>(unPackOp.getSource()).getResultNumber()); 1698 packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp()); 1699 if (!packOp || !packOp.getResult().hasOneUse()) 1700 return emitSilenceableError() << "could not find matching pack op"; 1701 } 1702 1703 // Step 2.5. Fail if any permutation does not validate. 1704 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) { 1705 ArrayRef<int64_t> perm = 1706 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm(); 1707 auto errorMsg = (permType == OuterOrInnerPerm::Outer) 1708 ? StringLiteral{"invalid outer_perm"} 1709 : StringLiteral{"invalid inner_perm"}; 1710 if (!isValidPackingPermutation(packOp, perm, permType) || 1711 !isValidPackingPermutation(unPackOp, perm, permType)) { 1712 Operation *packOrUnpackOp = 1713 unPackOp ? unPackOp.getOperation() : packOp.getOperation(); 1714 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp; 1715 } 1716 } 1717 1718 // From here on, packOp and linalgOp are always present, unPackOp may or may 1719 // not be present. 1720 assert(packOp && linalgOp && "unexpected null op"); 1721 1722 // Step 3. Actually transpose the ops. 1723 FailureOr<PackTransposeResult> res = packTranspose( 1724 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm()); 1725 // Preconditions have been checked, it is an error to fail here. 1726 assert(succeeded(res) && "unexpected packTranspose failure"); 1727 1728 // Step 4. Return results. 1729 transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp}); 1730 transformResults.set(cast<OpResult>(getPackedOp()), 1731 {res->transposedLinalgOp}); 1732 if (unPackOp) { 1733 transformResults.set(cast<OpResult>(getUnPackOp()), 1734 {res->transposedUnPackOp}); 1735 } else { 1736 transformResults.set(cast<OpResult>(getUnPackOp()), {}); 1737 } 1738 1739 return DiagnosedSilenceableFailure::success(); 1740 } 1741 1742 //===---------------------------------------------------------------------===// 1743 // PadOp 1744 //===---------------------------------------------------------------------===// 1745 1746 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target, 1747 ArrayRef<int64_t> paddingDimensions, 1748 ArrayRef<int64_t> padToMultipleOf, 1749 ArrayRef<int64_t> nofoldFlags, 1750 ArrayRef<Attribute> transposePaddings, 1751 StringRef copyBackOp) { 1752 auto resultType = transform::AnyOpType::get(b.getContext()); 1753 return build(/*builder=*/b, 1754 /*result=*/result, 1755 /*types=*/TypeRange{resultType, resultType}, 1756 /*target=*/target, 1757 /*paddingValues=*/ArrayAttr(), // let inference handle this 1758 /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions), 1759 /*padToMultipleOf=*/ValueRange{}, 1760 /*padToMultipleOf=*/ 1761 (padToMultipleOf.empty() 1762 ? 
DenseI64ArrayAttr() 1763 : b.getDenseI64ArrayAttr(padToMultipleOf)), 1764 /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags), 1765 /*transposePaddings=*/b.getArrayAttr(transposePaddings), 1766 /*copyBackOp=*/b.getStringAttr(copyBackOp)); 1767 } 1768 1769 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target, 1770 ArrayRef<int64_t> paddingDimensions, 1771 ArrayRef<OpFoldResult> mixedPadToMultipleOf, 1772 ArrayRef<int64_t> nofoldFlags, 1773 ArrayRef<Attribute> transposePaddings, 1774 StringRef copyBackOp) { 1775 auto resultType = transform::AnyOpType::get(b.getContext()); 1776 SmallVector<int64_t> staticPadToMultipleOf; 1777 SmallVector<Value> dynamicPadToMultipleOf; 1778 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf, 1779 staticPadToMultipleOf); 1780 return build(/*builder=*/b, 1781 /*result=*/result, 1782 /*types=*/TypeRange{resultType, resultType}, 1783 /*target=*/target, 1784 /*paddingValues=*/ArrayAttr(), // let inference handle this 1785 /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions), 1786 /*padToMultipleOf=*/dynamicPadToMultipleOf, 1787 /*padToMultipleOf=*/staticPadToMultipleOf, 1788 /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags), 1789 /*transposePaddings=*/b.getArrayAttr(transposePaddings), 1790 /*copyBackOp=*/b.getStringAttr(copyBackOp)); 1791 } 1792 1793 void PadOp::getEffects( 1794 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1795 consumesHandle(getTargetMutable(), effects); 1796 onlyReadsHandle(getPadToMultipleOfMutable(), effects); 1797 producesHandle(getOperation()->getOpResults(), effects); 1798 modifiesPayload(effects); 1799 } 1800 1801 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() { 1802 Builder b(getContext()); 1803 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b); 1804 } 1805 1806 DiagnosedSilenceableFailure 1807 transform::PadOp::apply(transform::TransformRewriter &rewriter, 1808 transform::TransformResults &results, 1809 transform::TransformState &state) { 1810 auto transformOp = cast<TransformOpInterface>(getOperation()); 1811 SmallVector<Operation *> paddedOps, padOps, copyBackOps; 1812 1813 for (Operation *target : state.getPayloadOps(getTarget())) { 1814 auto linalgTarget = dyn_cast<LinalgOp>(target); 1815 if (!linalgTarget) { 1816 auto diag = emitSilenceableError() << "expected LinalgOp target"; 1817 diag.attachNote(target->getLoc()) << "target op"; 1818 return diag; 1819 } 1820 1821 // Convert the integer packing flags to booleans. 1822 SmallVector<bool> nofoldFlags; 1823 for (int64_t packPadding : 1824 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags())) 1825 nofoldFlags.push_back(static_cast<bool>(packPadding)); 1826 1827 // Convert the padding values to attributes. 1828 SmallVector<Attribute> paddingValues; 1829 for (auto const &it : 1830 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) { 1831 auto attr = dyn_cast<TypedAttr>(std::get<0>(it)); 1832 if (!attr) { 1833 emitOpError("expects padding values to be typed attributes"); 1834 return DiagnosedSilenceableFailure::definiteFailure(); 1835 } 1836 Type elementType = getElementTypeOrSelf(std::get<1>(it)); 1837 // Try to parse string attributes to obtain an attribute of element type. 
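      // For instance (illustrative only), a padding value given as the string
      // "0.0" for an operand with f32 elements parses to the typed attribute
      // 0.0 : f32; anything that does not parse to the element type is
      // rejected below.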
1838 if (auto stringAttr = dyn_cast<StringAttr>(attr)) { 1839 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute( 1840 stringAttr, getContext(), elementType, 1841 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); 1842 if (!parsedAttr || parsedAttr.getType() != elementType) { 1843 auto diag = this->emitOpError("expects a padding that parses to ") 1844 << elementType << ", got " << std::get<0>(it); 1845 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; 1846 return DiagnosedSilenceableFailure::definiteFailure(); 1847 } 1848 paddingValues.push_back(parsedAttr); 1849 continue; 1850 } 1851 // Otherwise, add the attribute directly. 1852 if (attr.getType() != elementType) { 1853 auto diag = this->emitOpError("expects a padding value of type ") 1854 << elementType << ", got " << attr; 1855 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; 1856 return DiagnosedSilenceableFailure::definiteFailure(); 1857 } 1858 paddingValues.push_back(attr); 1859 } 1860 1861 // Extract the transpose vectors. 1862 SmallVector<SmallVector<int64_t>> transposePaddings; 1863 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings())) 1864 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>( 1865 cast<ArrayAttr>(transposeVector))); 1866 1867 LinalgOp paddedOp; 1868 LinalgPaddingOptions options; 1869 options.paddingDimensions = 1870 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()); 1871 1872 SmallVector<int64_t> padToMultipleOf; 1873 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( 1874 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf); 1875 if (!status.succeeded()) 1876 return status; 1877 if (padToMultipleOf.empty()) 1878 padToMultipleOf = 1879 SmallVector<int64_t>(options.paddingDimensions.size(), 1); 1880 1881 options.padToMultipleOf = padToMultipleOf; 1882 options.paddingValues = paddingValues; 1883 options.nofoldFlags = nofoldFlags; 1884 if (getCopyBackOp() == 1885 bufferization::MaterializeInDestinationOp::getOperationName()) { 1886 options.copyBackOp = LinalgPaddingOptions::CopyBackOp:: 1887 BufferizationMaterializeInDestination; 1888 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) { 1889 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy; 1890 } else if (getCopyBackOp() == kCopyOpNone) { 1891 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None; 1892 } else { 1893 llvm_unreachable("unsupported copy_back op"); 1894 } 1895 1896 SmallVector<Value> replacements; 1897 SmallVector<tensor::PadOp> newPadOps; 1898 if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp, 1899 replacements, newPadOps))) { 1900 auto diag = emitSilenceableError() << "failed to pad op"; 1901 diag.attachNote(target->getLoc()) << "target op"; 1902 return diag; 1903 } 1904 1905 // We need to perform our own replacement here because this API is still 1906 // used in patterns that "pad and hoist", for which the replacement values 1907 // need to be different. 1908 // TODO: clean this up and stop "pad and hoist" behavior more globally now 1909 // that we have more composable abstractions. 
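    // When a copy-back op was requested, every replacement value is produced
    // by such a copy-back op; these producers are collected (deduplicated)
    // below to populate the `copy` result handle.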
1910 rewriter.replaceOp(linalgTarget, replacements); 1911 paddedOps.push_back(paddedOp); 1912 padOps.append(newPadOps.begin(), newPadOps.end()); 1913 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) { 1914 for (Value v : replacements) { 1915 Operation *copyBackOp = v.getDefiningOp(); 1916 if (!llvm::is_contained(copyBackOps, copyBackOp)) 1917 copyBackOps.push_back(copyBackOp); 1918 } 1919 } 1920 } 1921 1922 results.set(cast<OpResult>(getPadded()), paddedOps); 1923 results.set(cast<OpResult>(getPad()), padOps); 1924 results.set(cast<OpResult>(getCopy()), copyBackOps); 1925 return DiagnosedSilenceableFailure::success(); 1926 } 1927 1928 LogicalResult transform::PadOp::verify() { 1929 SmallVector<int64_t> nofoldFlags = 1930 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()); 1931 if (any_of(nofoldFlags, [](int64_t packPadding) { 1932 return packPadding != 0 && packPadding != 1; 1933 })) { 1934 return emitOpError() 1935 << "expects nofold_flags to contain booleans (0/1), found " 1936 << getNofoldFlags(); 1937 } 1938 1939 SmallVector<int64_t> paddingDimensions = 1940 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()); 1941 if (any_of(paddingDimensions, 1942 [](int64_t paddingDimension) { return paddingDimension < 0; })) { 1943 return emitOpError() << "expects padding_dimensions to contain positive " 1944 "integers, found " 1945 << getPaddingDimensions(); 1946 } 1947 if (!getMixedPadToMultipleOf().empty()) { 1948 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) { 1949 return emitOpError() << "expects as many multiples as padding_dimensions"; 1950 } 1951 } 1952 ArrayAttr transposes = getTransposePaddings(); 1953 for (Attribute attr : transposes) { 1954 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr); 1955 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 1956 if (!std::is_permutation(sequence.begin(), sequence.end(), 1957 transpose.begin(), transpose.end())) { 1958 return emitOpError() 1959 << "expects transpose_paddings to be a permutation, found " 1960 << attr; 1961 } 1962 } 1963 if (getCopyBackOp() != 1964 bufferization::MaterializeInDestinationOp::getOperationName() && 1965 getCopyBackOp() != linalg::CopyOp::getOperationName() && 1966 getCopyBackOp() != kCopyOpNone) 1967 return emitOpError() << "invalid copy_back_op"; 1968 return success(); 1969 } 1970 1971 //===---------------------------------------------------------------------===// 1972 // HoistPadOp 1973 //===---------------------------------------------------------------------===// 1974 1975 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( 1976 transform::TransformRewriter &rewriter, 1977 transform::TransformResults &transformResults, 1978 transform::TransformState &state) { 1979 auto targetOps = state.getPayloadOps(getTarget()); 1980 auto loopOps = state.getPayloadOps(getLoop()); 1981 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) { 1982 return emitDefiniteFailure() 1983 << "requires exactly one target and one loop handle (got " 1984 << llvm::range_size(targetOps) << " and " 1985 << llvm::range_size(loopOps) << ")"; 1986 } 1987 1988 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin()); 1989 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin()); 1990 if (!padOp || !loopOp) 1991 return emitDefiniteFailure() << "requires exactly 2 non-null handles"; 1992 1993 FailureOr<linalg::detail::PackingResult> result = 1994 linalg::detail::buildPackingLoopNest(rewriter, 
padOp, loopOp, 1995 getTranspose()); 1996 if (failed(result)) 1997 return emitDefiniteFailure() << "could not build packing loop nest"; 1998 1999 if (result->clonedLoopIvs.empty()) { 2000 transformResults.set(cast<OpResult>(getPackingLoop()), 2001 {result->hoistedPadOp.getOperation()}); 2002 return DiagnosedSilenceableFailure::success(); 2003 } 2004 auto outerPackedLoop = 2005 scf::getForInductionVarOwner(result->clonedLoopIvs.front()); 2006 transformResults.set(cast<OpResult>(getPackingLoop()), 2007 {outerPackedLoop.getOperation()}); 2008 return DiagnosedSilenceableFailure::success(); 2009 } 2010 2011 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() { 2012 ArrayRef<int64_t> transpose = getTranspose(); 2013 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 2014 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), 2015 transpose.end())) { 2016 return emitOpError() << "expects transpose to be a permutation, found " 2017 << getTranspose(); 2018 } 2019 return success(); 2020 } 2021 2022 void transform::HoistPadBuildPackingLoopNestOp::getEffects( 2023 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2024 transform::onlyReadsHandle(getTargetMutable(), effects); 2025 transform::onlyReadsHandle(getLoopMutable(), effects); 2026 transform::producesHandle(getOperation()->getOpResults(), effects); 2027 transform::modifiesPayload(effects); 2028 } 2029 2030 DiagnosedSilenceableFailure 2031 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter, 2032 tensor::PadOp target, 2033 transform::ApplyToEachResultList &results, 2034 transform::TransformState &state) { 2035 tensor::PadOp hoistedPadOp; 2036 SmallVector<TransposeOp> transposeOps; 2037 FailureOr<Value> result = 2038 hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), 2039 hoistedPadOp, transposeOps); 2040 if (succeeded(result)) { 2041 // We need to perform our own replacement here because this API is still 2042 // used in patterns that "pad and hoist", for which the replacement values 2043 // need to be different. 2044 // TODO: clean this up and stop "pad and hoist" behavior more globally now 2045 // that we have more composable abstractions. 
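    // `hoistPaddingOnTensors` returns the value that replaces the original
    // tensor.pad; the hoisted pad op itself is reported through `hoistedPadOp`
    // and becomes this transform's result.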
2046 rewriter.replaceOp(target, *result); 2047 results.push_back(hoistedPadOp); 2048 return DiagnosedSilenceableFailure::success(); 2049 } 2050 return emitDefaultSilenceableFailure(target); 2051 } 2052 2053 LogicalResult transform::HoistPadOp::verify() { 2054 ArrayRef<int64_t> transpose = getTranspose(); 2055 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); 2056 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(), 2057 transpose.end())) { 2058 return emitOpError() << "expects transpose to be a permutation, found " 2059 << getTranspose(); 2060 } 2061 return success(); 2062 } 2063 2064 //===----------------------------------------------------------------------===// 2065 // PromoteOp 2066 //===----------------------------------------------------------------------===// 2067 2068 DiagnosedSilenceableFailure 2069 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, 2070 LinalgOp target, 2071 transform::ApplyToEachResultList &results, 2072 transform::TransformState &state) { 2073 LinalgPromotionOptions promotionOptions; 2074 if (!getOperandsToPromote().empty()) 2075 promotionOptions = promotionOptions.setOperandsToPromote( 2076 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote())); 2077 if (getUseFullTilesByDefault()) 2078 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( 2079 getUseFullTilesByDefault()); 2080 if (getUseAlloca()) 2081 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); 2082 if (!getUseFullTileBuffers().empty()) 2083 promotionOptions = promotionOptions.setUseFullTileBuffers( 2084 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); 2085 if (getAlignment().has_value()) 2086 promotionOptions = promotionOptions.setAlignment(*getAlignment()); 2087 if (getMemorySpace().has_value()) 2088 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace()); 2089 2090 if (getMapping().has_value()) { 2091 // The mapping should only contain an element 2092 auto mapping = *getMapping(); 2093 if (mapping.size() > 1) 2094 return emitDefaultDefiniteFailure(target); 2095 2096 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]); 2097 2098 if (addressSpace.getAddressSpace() == 2099 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) { 2100 promotionOptions = 2101 promotionOptions 2102 .setAllocationDeallocationFns(allocateWorkgroupMemory, 2103 deallocateWorkgroupMemory) 2104 .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory) 2105 .setUseFullTileBuffers({false, false}); 2106 } else if (addressSpace.getAddressSpace() == 2107 mlir::gpu::GPUDialect::getPrivateAddressSpace()) { 2108 promotionOptions = 2109 promotionOptions 2110 .setAllocationDeallocationFns(allocateGPUPrivateMemory, 2111 deallocateGPUPrivateMemory) 2112 .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory) 2113 .setUseFullTileBuffers({false, false}); 2114 } else { 2115 return emitDefaultDefiniteFailure(target); 2116 } 2117 } 2118 2119 if (failed(promoteSubviewsPrecondition(target, promotionOptions))) 2120 return emitDefaultDefiniteFailure(target); 2121 2122 rewriter.setInsertionPoint(target); 2123 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); 2124 if (failed(res)) 2125 return emitDefaultDefiniteFailure(target); 2126 results.push_back(target); 2127 return DiagnosedSilenceableFailure::success(); 2128 } 2129 2130 //===----------------------------------------------------------------------===// 2131 // ReplaceOp 2132 
//===----------------------------------------------------------------------===// 2133 2134 DiagnosedSilenceableFailure 2135 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter, 2136 TransformResults &transformResults, 2137 TransformState &state) { 2138 auto payload = state.getPayloadOps(getTarget()); 2139 2140 // Check for invalid targets. 2141 for (Operation *target : payload) { 2142 if (target->getNumOperands() > 0) 2143 return emitDefiniteFailure() << "expected target without operands"; 2144 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() && 2145 target->getNumRegions() > 0) 2146 return emitDefiniteFailure() 2147 << "expected target that is isolated from above"; 2148 } 2149 2150 // Clone and replace. 2151 Operation *pattern = &getBodyRegion().front().front(); 2152 SmallVector<Operation *> replacements; 2153 for (Operation *target : payload) { 2154 if (getOperation()->isAncestor(target)) 2155 continue; 2156 rewriter.setInsertionPoint(target); 2157 Operation *replacement = rewriter.clone(*pattern); 2158 rewriter.replaceOp(target, replacement->getResults()); 2159 replacements.push_back(replacement); 2160 } 2161 transformResults.set(cast<OpResult>(getReplacement()), replacements); 2162 return DiagnosedSilenceableFailure::success(); 2163 } 2164 2165 void transform::ReplaceOp::getEffects( 2166 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2167 consumesHandle(getTargetMutable(), effects); 2168 producesHandle(getOperation()->getOpResults(), effects); 2169 modifiesPayload(effects); 2170 } 2171 2172 LogicalResult transform::ReplaceOp::verify() { 2173 if (!getBodyRegion().hasOneBlock()) 2174 return emitOpError() << "expected one block"; 2175 if (std::distance(getBodyRegion().front().begin(), 2176 getBodyRegion().front().end()) != 1) 2177 return emitOpError() << "expected one operation in block"; 2178 Operation *replacement = &getBodyRegion().front().front(); 2179 if (replacement->getNumOperands() > 0) 2180 return replacement->emitOpError() 2181 << "expected replacement without operands"; 2182 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() && 2183 replacement->getNumRegions() > 0) 2184 return replacement->emitOpError() 2185 << "expect op that is isolated from above"; 2186 return success(); 2187 } 2188 2189 //===----------------------------------------------------------------------===// 2190 // ScalarizeOp 2191 //===----------------------------------------------------------------------===// 2192 2193 DiagnosedSilenceableFailure 2194 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, 2195 LinalgOp target, 2196 transform::ApplyToEachResultList &results, 2197 transform::TransformState &state) { 2198 scf::SCFTilingOptions tilingOptions; 2199 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { 2200 SmallVector<OpFoldResult> tileSizes; 2201 Location loc = target.getLoc(); 2202 SmallVector<OpFoldResult> allShapeSizes = 2203 target.createFlatListOfOperandDims(b, loc); 2204 AffineMap map = target.getShapesToLoopsMap(); 2205 if (!map) 2206 return tileSizes; 2207 SmallVector<OpFoldResult> shapeSizes = 2208 affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, 2209 allShapeSizes); 2210 // If the shape size is dynamic, tile by 1. 2211 // Otherwise, do not tile (i.e. tile size 0). 2212 for (OpFoldResult shapeSize : shapeSizes) { 2213 tileSizes.push_back(getConstantIntValue(shapeSize) ? 
b.getIndexAttr(0) 2214 : b.getIndexAttr(1)); 2215 } 2216 return tileSizes; 2217 }); 2218 SmallVector<int64_t> emptyTileSizes; 2219 rewriter.setInsertionPoint(target); 2220 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF( 2221 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions); 2222 if (failed(maybeTilingResult)) 2223 return emitDefaultDefiniteFailure(target); 2224 2225 if (target->getNumResults()) 2226 rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements); 2227 else 2228 rewriter.eraseOp(target); 2229 2230 results.reserve(maybeTilingResult->tiledOps.size()); 2231 for (Operation *tiled : maybeTilingResult->tiledOps) 2232 results.push_back(tiled); 2233 return DiagnosedSilenceableFailure::success(); 2234 } 2235 2236 //===----------------------------------------------------------------------===// 2237 // ConvertToLoopsOp 2238 //===----------------------------------------------------------------------===// 2239 2240 DiagnosedSilenceableFailure 2241 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter, 2242 transform::TransformResults &results, 2243 transform::TransformState &state) { 2244 SmallVector<Operation *> loops; 2245 for (Operation *target : state.getPayloadOps(getTarget())) { 2246 auto tilingOp = dyn_cast<TilingInterface>(*target); 2247 if (!tilingOp) { 2248 DiagnosedSilenceableFailure diag = 2249 emitSilenceableError() 2250 << "expected the payload to implement TilingInterface"; 2251 diag.attachNote(target->getLoc()) << "payload op"; 2252 return diag; 2253 } 2254 rewriter.setInsertionPoint(target); 2255 FailureOr<SmallVector<scf::ForOp>> generatedLoops = 2256 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp); 2257 if (failed(generatedLoops)) 2258 return emitDefaultDefiniteFailure(target); 2259 for (scf::ForOp &loop : *generatedLoops) { 2260 loops.push_back(loop.getOperation()); 2261 } 2262 rewriter.eraseOp(target); 2263 } 2264 results.set(cast<OpResult>(getResult()), loops); 2265 return DiagnosedSilenceableFailure::success(); 2266 } 2267 2268 //===----------------------------------------------------------------------===// 2269 // RewriteInDestinationPassingStyleOp 2270 //===----------------------------------------------------------------------===// 2271 2272 DiagnosedSilenceableFailure 2273 transform::RewriteInDestinationPassingStyleOp::applyToOne( 2274 transform::TransformRewriter &rewriter, Operation *target, 2275 transform::ApplyToEachResultList &results, 2276 transform::TransformState &state) { 2277 SmallVector<Operation *> res; 2278 rewriter.setInsertionPoint(target); 2279 FailureOr<Operation *> maybeResult = 2280 TypeSwitch<Operation *, FailureOr<Operation *>>(target) 2281 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>( 2282 [&rewriter](auto op) { 2283 return rewriteInDestinationPassingStyle(rewriter, op); 2284 }); 2285 if (failed(maybeResult)) 2286 return emitDefaultSilenceableFailure(target); 2287 results.push_back(*maybeResult); 2288 return DiagnosedSilenceableFailure::success(); 2289 } 2290 2291 //===----------------------------------------------------------------------===// 2292 // SplitOp 2293 //===----------------------------------------------------------------------===// 2294 2295 DiagnosedSilenceableFailure 2296 SplitOp::apply(transform::TransformRewriter &rewriter, 2297 TransformResults &results, TransformState &state) { 2298 // Collect the dynamic split points if provided. 
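  // Chunk sizes are either the static attribute, broadcast to every target, or
  // taken from the dynamic handle: one single-result index-typed payload op
  // (or one param) per target. With multiway split enabled, a single target is
  // instead split at each provided chunk size in turn.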
2299 SmallVector<Operation *> payload = 2300 llvm::to_vector(state.getPayloadOps(getTarget())); 2301 2302 bool isMultiwaySplit = getMultiway(); 2303 2304 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) { 2305 return mlir::emitSilenceableFailure(getLoc()) 2306 << "requires exactly one target when " 2307 "multiway split is enabled (got " 2308 << llvm::range_size(payload) << ")"; 2309 } 2310 2311 SmallVector<OpFoldResult> chunkSizes; 2312 2313 if (!isMultiwaySplit) 2314 chunkSizes.reserve(payload.size()); 2315 2316 if (getDynamicChunkSizes()) { 2317 auto diag = DiagnosedSilenceableFailure::success(); 2318 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) { 2319 chunkSizes = llvm::to_vector(llvm::map_range( 2320 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) { 2321 if (op->getNumResults() != 1 || 2322 !op->getResult(0).getType().isIndex()) { 2323 diag = emitSilenceableError() 2324 << "expected dynamic split point handle to point to a " 2325 "single-result index-typed op"; 2326 diag.attachNote(op->getLoc()) << "dynamic split point"; 2327 } 2328 return OpFoldResult(op->getResult(0)); 2329 })); 2330 } else { 2331 chunkSizes = llvm::to_vector( 2332 llvm::map_range(state.getParams(getDynamicChunkSizes()), 2333 [](Attribute attr) { return OpFoldResult(attr); })); 2334 } 2335 if (diag.isSilenceableFailure()) 2336 return diag; 2337 2338 // For multiway split, a single payload is expected to have multiple 2339 // split points. 2340 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) { 2341 return emitDefiniteFailure() 2342 << "expected the dynamic split point handle to point to as " 2343 "many operations (" 2344 << chunkSizes.size() << ") as the target handle (" 2345 << payload.size() << ")"; 2346 } 2347 } else { 2348 chunkSizes.resize(payload.size(), 2349 rewriter.getIndexAttr(getStaticChunkSizes())); 2350 } 2351 2352 auto checkStructuredOpAndDimensions = 2353 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure { 2354 if (!linalgOp) { 2355 auto diag = emitSilenceableError() << "only applies to structured ops"; 2356 diag.attachNote(loc) << "target op"; 2357 return diag; 2358 } 2359 2360 if (getDimension() >= linalgOp.getNumLoops()) { 2361 auto diag = emitSilenceableError() << "dimension " << getDimension() 2362 << " does not exist in target op"; 2363 diag.attachNote(loc) << "target op"; 2364 return diag; 2365 } 2366 return DiagnosedSilenceableFailure::success(); 2367 }; 2368 2369 auto checkFailureInSplitting = 2370 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure { 2371 if (hasFailed) { 2372 auto diag = emitDefiniteFailure() << "internal failure in splitting"; 2373 diag.attachNote(loc) << "target op"; 2374 return diag; 2375 } 2376 return DiagnosedSilenceableFailure::success(); 2377 }; 2378 2379 SmallVector<Operation *> opList; 2380 if (isMultiwaySplit) { 2381 2382 // Split a single target operation at multiple points. 2383 TilingInterface head, tail; 2384 Operation *target = payload.front(); 2385 2386 LinalgOp linalgOp = dyn_cast<LinalgOp>(target); 2387 2388 // Check that the target is a valid LinalgOp with correct dimensions. 
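    // Multiway splitting peels one chunk at a time: each iteration splits the
    // current tail at the next chunk size, records the head, and continues on
    // the remaining tail.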
2389     DiagnosedSilenceableFailure diag =
2390         checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2391     if (diag.isSilenceableFailure())
2392       return diag;
2393
2394     for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2395
2396       if (idx > 0)
2397         target = tail.getOperation();
2398
2399       if (!target)
2400         break;
2401
2402       linalgOp = cast<LinalgOp>(target);
2403       Location loc = target->getLoc();
2404
2405       rewriter.setInsertionPoint(linalgOp);
2406       std::tie(head, tail) = linalg::splitOp(
2407           rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2408           getDimension(), chunkSize);
2409
2410       // Propagate errors.
2411       DiagnosedSilenceableFailure diag =
2412           checkFailureInSplitting(!head && !tail, loc);
2413       if (diag.isDefiniteFailure())
2414         return diag;
2415
2416       opList.push_back(head.getOperation());
2417     }
2418
2419     // Append any leftover parts to the end of the result list.
2420     if (tail)
2421       opList.push_back(tail.getOperation());
2422
2423   } else {
2424     // Split each target operation.
2425     SmallVector<Operation *> first, second;
2426     Operation *noSecondPart = nullptr;
2427     for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2428       Operation *target = std::get<0>(pair);
2429       Location loc = target->getLoc();
2430       LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2431       DiagnosedSilenceableFailure diag =
2432           checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2433
2434       if (diag.isSilenceableFailure())
2435         return diag;
2436
2437       rewriter.setInsertionPoint(linalgOp);
2438       std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2439           rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2440           getDimension(), std::get<1>(pair));
2441
2442       // Propagate errors.
2443       DiagnosedSilenceableFailure diagSplit =
2444           checkFailureInSplitting(!first.back() && !second.back(), loc);
2445       if (diagSplit.isDefiniteFailure())
2446         return diagSplit;
2447
2448       // Do not add null second parts.
2449 if (!second.back()) { 2450 noSecondPart = target; 2451 second.pop_back(); 2452 } 2453 } 2454 2455 if (second.size() != first.size() && !second.empty()) { 2456 auto diag = emitSilenceableError() 2457 << "splitting does not produce the second part for a subset " 2458 "of targets"; 2459 diag.attachNote() 2460 << "expected splitting to produce the second part of all " 2461 "or none of the targets"; 2462 diag.attachNote(noSecondPart->getLoc()) 2463 << "first target with no second part"; 2464 return diag; 2465 } 2466 2467 opList.append(first); 2468 if (second.size()) 2469 opList.append(second); 2470 } 2471 results.set(cast<OpResult>(getSplitList()), opList); 2472 return DiagnosedSilenceableFailure::success(); 2473 } 2474 2475 void SplitOp::getEffects( 2476 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2477 consumesHandle(getTargetMutable(), effects); 2478 if (getDynamicChunkSizes()) 2479 onlyReadsHandle(getDynamicChunkSizesMutable(), effects); 2480 producesHandle(getOperation()->getOpResults(), effects); 2481 modifiesPayload(effects); 2482 } 2483 2484 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { 2485 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes; 2486 IntegerAttr staticChunkSizes; 2487 if (parser.parseOperand(target) || parser.parseKeyword("after")) 2488 return failure(); 2489 2490 OptionalParseResult dynamicPointParseResult = 2491 parser.parseOptionalOperand(dynamicChunkSizes); 2492 if (!dynamicPointParseResult.has_value()) { 2493 int64_t staticChunkSizesValue; 2494 if (failed(parser.parseInteger(staticChunkSizesValue))) 2495 return failure(); 2496 2497 staticChunkSizes = 2498 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue); 2499 } 2500 2501 Type targetType; 2502 if (parser.parseOptionalAttrDict(result.attributes) || 2503 parser.parseColonType(targetType) || 2504 parser.resolveOperand(target, targetType, result.operands)) { 2505 return failure(); 2506 } 2507 if (dynamicPointParseResult.has_value()) { 2508 Type ChunkSizesType; 2509 if (failed(*dynamicPointParseResult) || parser.parseComma() || 2510 parser.parseType(ChunkSizesType) || 2511 parser.resolveOperand(dynamicChunkSizes, ChunkSizesType, 2512 result.operands)) { 2513 return failure(); 2514 } 2515 2516 staticChunkSizes = 2517 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); 2518 } 2519 2520 result.addAttribute( 2521 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(), 2522 staticChunkSizes); 2523 result.addTypes(targetType); 2524 return success(); 2525 } 2526 2527 void SplitOp::print(OpAsmPrinter &printer) { 2528 printer << " " << getTarget() << " after "; 2529 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes()); 2530 if (staticChunkSize != ShapedType::kDynamic) 2531 printer << staticChunkSize; 2532 else 2533 printer << getDynamicChunkSizes(); 2534 printer << " "; 2535 printer.printOptionalAttrDict(getOperation()->getAttrs(), 2536 {getStaticChunkSizesAttrName()}); 2537 printer << " : " << getTarget().getType(); 2538 if (staticChunkSize == ShapedType::kDynamic) 2539 printer << ", " << getDynamicChunkSizes().getType(); 2540 } 2541 2542 LogicalResult SplitOp::verify() { 2543 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^ 2544 (getDynamicChunkSizes() == nullptr)) { 2545 return emitOpError() << "expects either a dynamic or a static split " 2546 "point to be provided"; 2547 } 2548 return success(); 2549 } 2550 2551 //===----------------------------------------------------------------------===// 2552 // 
SplitReductionOp 2553 //===----------------------------------------------------------------------===// 2554 2555 void transform::SplitReductionOp::build( 2556 OpBuilder &builder, OperationState &result, Value target, 2557 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, 2558 bool useScalingAlgorithm, bool useAlloc) { 2559 MLIRContext *ctx = builder.getContext(); 2560 result.addOperands(target); 2561 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), 2562 builder.getI64IntegerAttr(splitFactor)); 2563 result.addAttribute( 2564 SplitReductionOp::getInsertSplitDimensionAttrName(result.name), 2565 builder.getI64IntegerAttr(insertSplitDimension)); 2566 if (innerParallel) { 2567 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), 2568 builder.getUnitAttr()); 2569 } 2570 if (useScalingAlgorithm) { 2571 result.addAttribute( 2572 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), 2573 builder.getUnitAttr()); 2574 } 2575 if (useAlloc) { 2576 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), 2577 builder.getUnitAttr()); 2578 } 2579 auto resultType = transform::AnyOpType::get(ctx); 2580 result.addTypes({resultType, resultType, resultType, resultType}); 2581 } 2582 2583 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( 2584 transform::TransformRewriter &rewriter, LinalgOp target, 2585 transform::ApplyToEachResultList &results, 2586 transform::TransformState &state) { 2587 ControlSplitReductionFn splitFn = [&](LinalgOp) { 2588 return linalg::SplitReductionOptions{int64_t(getSplitFactor()), 2589 unsigned(getInsertSplitDimension()), 2590 bool(getInnerParallel())}; 2591 }; 2592 rewriter.setInsertionPoint(target); 2593 FailureOr<SplitReductionResult> splitResult = 2594 (getUseScalingAlgorithm()) 2595 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) 2596 : splitReduction(rewriter, target, splitFn, getUseAlloc()); 2597 if (failed(splitResult)) 2598 return emitDefaultDefiniteFailure(target); 2599 2600 results.push_back(splitResult->initOrAlloc); 2601 results.push_back(splitResult->fillOp); 2602 results.push_back(splitResult->splitLinalgOp); 2603 results.push_back(splitResult->resultCombiningLinalgOp); 2604 return DiagnosedSilenceableFailure::success(); 2605 } 2606 2607 //===----------------------------------------------------------------------===// 2608 // TileReductionUsingForOp 2609 //===----------------------------------------------------------------------===// 2610 2611 void transform::TileReductionUsingForOp::build( 2612 OpBuilder &builder, OperationState &result, Value target, 2613 ArrayRef<int64_t> staticTileSizes) { 2614 // Call the default builder. 2615 // This is future-proof re mixed static-dynamic and setting up the proper 2616 // operands segment sizes attributes for multiple variadic operands. 2617 // In the absence of this, horrible bugs ensue. 2618 // TODO: support mixed static-dynamic (see TileUsingForallOp). 
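  // Four any_op results are declared; in `applyToOne` they are populated, in
  // order, with the init value producer(s), the partially tiled op(s), the
  // merge op(s), and the outermost generated loop.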
2619 MLIRContext *ctx = builder.getContext(); 2620 auto opTy = transform::AnyOpType::get(ctx); 2621 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); 2622 build(builder, result, 2623 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, 2624 /*target=*/target, 2625 /*tile_sizes=*/staticTileSizesAttr); 2626 } 2627 2628 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne( 2629 transform::TransformRewriter &rewriter, Operation *target, 2630 transform::ApplyToEachResultList &results, 2631 transform::TransformState &state) { 2632 rewriter.setInsertionPoint(target); 2633 2634 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target); 2635 if (!partialReductionOp) { 2636 return emitSilenceableFailure( 2637 target->getLoc(), 2638 "Operation should implement PartialReductionOpInterface"); 2639 } 2640 FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf( 2641 rewriter, partialReductionOp, 2642 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()))); 2643 2644 if (failed(result)) 2645 return emitDefaultSilenceableFailure(target); 2646 rewriter.replaceOp(target, result->mergeResult.replacements); 2647 for (Value initValue : result->initialValues) 2648 results.push_back(initValue.getDefiningOp()); 2649 for (auto parallelTiledOp : result->tiledOps) 2650 results.push_back(parallelTiledOp); 2651 for (auto mergeOp : result->mergeResult.mergeOps) 2652 results.push_back(mergeOp); 2653 results.push_back(result->loops.front()); 2654 return DiagnosedSilenceableFailure::success(); 2655 } 2656 2657 //===----------------------------------------------------------------------===// 2658 // TileReductionUsingForallOp 2659 //===----------------------------------------------------------------------===// 2660 2661 void transform::TileReductionUsingForallOp::build( 2662 OpBuilder &builder, OperationState &result, Value target, 2663 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes, 2664 ArrayAttr mapping) { 2665 // Call the default builder. 2666 // This is future-proof re mixed static-dynamic and setting up the proper 2667 // operands segment sizes attributes for multiple variadic operands. 2668 // In the absence of this, horrible bugs ensue. 2669 // TODO: support mixed static-dynamic (see TileUsingForallOp). 
2670 MLIRContext *ctx = builder.getContext(); 2671 auto opTy = transform::AnyOpType::get(ctx); 2672 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); 2673 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); 2674 build(builder, result, 2675 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, 2676 /*target=*/target, 2677 /*num_threads=*/staticNumThreadsAttr, 2678 /*tile_sizes=*/staticTileSizesAttr, 2679 /*mapping=*/mapping); 2680 } 2681 2682 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( 2683 transform::TransformRewriter &rewriter, LinalgOp target, 2684 transform::ApplyToEachResultList &results, 2685 transform::TransformState &state) { 2686 rewriter.setInsertionPoint(target); 2687 SmallVector<OpFoldResult> numThreads = 2688 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); 2689 SmallVector<OpFoldResult> tileSizes = 2690 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())); 2691 FailureOr<linalg::ForallReductionTilingResult> result = 2692 linalg::tileReductionUsingForall( 2693 rewriter, cast<PartialReductionOpInterface>(target.getOperation()), 2694 numThreads, tileSizes, getMapping()); 2695 2696 if (failed(result)) { 2697 auto diag = emitSilenceableError() << "could not tile reduction"; 2698 diag.attachNote(target.getLoc()) << "target operation"; 2699 return diag; 2700 } 2701 for (Value initValue : result->initialValues) 2702 results.push_back(initValue.getDefiningOp()); 2703 for (auto parallelTiledOp : result->parallelTiledOps) 2704 results.push_back(parallelTiledOp); 2705 for (auto mergeOp : result->mergeOps) 2706 results.push_back(mergeOp); 2707 results.push_back(result->loops); 2708 return DiagnosedSilenceableFailure::success(); 2709 } 2710 2711 //===----------------------------------------------------------------------===// 2712 // ContinuousTileSizesOp 2713 //===----------------------------------------------------------------------===// 2714 2715 DiagnosedSilenceableFailure 2716 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter, 2717 TransformResults &transformResults, 2718 TransformState &state) { 2719 2720 SmallVector<Operation *> targetOps = 2721 llvm::to_vector(state.getPayloadOps(getTarget())); 2722 2723 if (!llvm::hasSingleElement(targetOps)) { 2724 return mlir::emitSilenceableFailure(getLoc()) 2725 << "requires exactly one target (got " << llvm::range_size(targetOps) 2726 << ")"; 2727 } 2728 2729 Operation *target = *targetOps.begin(); 2730 auto linalgOp = dyn_cast<LinalgOp>(target); 2731 auto tileableOp = dyn_cast<TilingInterface>(target); 2732 2733 if (!linalgOp) 2734 return emitDefiniteFailure() << "expected Linalg Op"; 2735 2736 OpBuilder builder(linalgOp.getContext()); 2737 2738 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) { 2739 if (linalgOp.hasDynamicShape()) { 2740 auto diag = emitSilenceableError() 2741 << "cannot compute parametric tile sizes for dynamically " 2742 "shaped payload op"; 2743 diag.attachNote(linalgOp->getLoc()) << "payload op"; 2744 return diag; 2745 } 2746 2747 FailureOr<StaticContinuousTileSizeSpecification> spec = 2748 computeStaticContinuousTileSizes(linalgOp, getDimension(), 2749 getTargetSize()); 2750 if (failed(spec)) { 2751 return emitSilenceableError() 2752 << "failed to compute multi-size tiling sizes"; 2753 } 2754 2755 SmallVector<int64_t> chunkSizes; 2756 2757 for (auto &&[tileSize, tripCount] : 2758 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) 2759 chunkSizes.push_back(tileSize * tripCount); 
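    // Each chunk size is the extent covered by the corresponding tile size,
    // i.e. tileSize * tripCount; e.g. a tile size of 8 repeated 4 times covers
    // a chunk of 32 iterations.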
2760
2761     auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2762       return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2763         return builder.getI64IntegerAttr(value);
2764       });
2765     };
2766     transformResults.setParams(cast<OpResult>(getTileSizes()),
2767                                getI64AttrsFromI64(spec->tileSizes));
2768     transformResults.setParams(cast<OpResult>(getChunkSizes()),
2769                                getI64AttrsFromI64(chunkSizes));
2770
2771     return DiagnosedSilenceableFailure::success();
2772   }
2773
2774   builder.setInsertionPoint(linalgOp);
2775
2776   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2777   unsigned dimension = getDimension();
2778
2779   FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2780       builder, tileableOp, dimension, targetSize, true);
2781   if (failed(spec)) {
2782     return emitSilenceableError() << "could not generate tile size computation";
2783   }
2784
2785   AffineExpr s0 = builder.getAffineSymbolExpr(0);
2786   AffineExpr s1 = builder.getAffineSymbolExpr(1);
2787   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2788     return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2789                                            ofrs);
2790   };
2791
2792   SmallVector<Value> chunkSizes;
2793   Value splitPoint;
2794   for (auto &&[tileSize, tripCount] :
2795        llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2796     splitPoint = apply(s0 * s1, {tileSize, tripCount});
2797     chunkSizes.push_back(splitPoint);
2798   }
2799
2800   auto getDefiningOps = [&](ArrayRef<Value> values) {
2801     return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2802       return value.getDefiningOp();
2803     });
2804   };
2805
2806   transformResults.set(cast<OpResult>(getTileSizes()),
2807                        getDefiningOps(spec->tileSizes));
2808   transformResults.set(cast<OpResult>(getChunkSizes()),
2809                        getDefiningOps(chunkSizes));
2810
2811   return DiagnosedSilenceableFailure::success();
2812 }
2813
2814 LogicalResult transform::ContinuousTileSizesOp::verify() {
2815
2816   if (getTileSizes().getType() != getChunkSizes().getType()) {
2817     return emitOpError() << "expects all result types to be the same";
2818   }
2819
2820   return success();
2821 }
2822
2823 void transform::ContinuousTileSizesOp::getEffects(
2824     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2825   if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2826     onlyReadsPayload(effects);
2827   else
2828     modifiesPayload(effects);
2829   onlyReadsHandle(getTargetMutable(), effects);
2830   producesHandle(getOperation()->getOpResults(), effects);
2831 }
2832
2833 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2834                                          Type targetType, Type tileSizesType,
2835                                          Type) {
2836   printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizesType});
2837 }
2838
2839 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2840                                                 Type &targetType,
2841                                                 Type &tileSizesType,
2842                                                 Type &chunkSizesType) {
2843   FunctionType funcType;
2844   llvm::SMLoc typeLoc = parser.getCurrentLocation();
2845   if (failed(parser.parseType<FunctionType>(funcType)))
2846     return failure();
2847
2848   if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2849     parser.emitError(typeLoc) << "expects a trailing functional type with one "
2850                                  "argument and one result";
    return failure();
2851   }
2852   targetType = funcType.getInput(0);
2853   tileSizesType = chunkSizesType = funcType.getResult(0);
2854
2855   return success();
2856 }
2857
2858 //===----------------------------------------------------------------------===//
2859 // TileUsingForOp
2860 //===----------------------------------------------------------------------===//
2861
2862 void transform::TileUsingForOp::build(
2863     OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2864     Value target, ArrayRef<int64_t> staticTileSizes,
2865     ArrayRef<int64_t> interchange,
2866     std::optional<ArrayRef<bool>> scalableSizes) {
2867   return build(builder, result, loopTypes,
2868                /*target=*/target,
2869                /*mixedTileSizes=*/
2870                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2871                interchange, scalableSizes);
2872 }
2873
2874 void transform::TileUsingForOp::build(
2875     OpBuilder &builder, OperationState &result, Value target,
2876     ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2877     std::optional<ArrayRef<bool>> scalableSizes) {
2878   build(builder, result, target,
2879         getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2880         interchange, scalableSizes);
2881 }
2882
2883 void transform::TileUsingForOp::build(
2884     OpBuilder &builder, OperationState &result, Value target,
2885     ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2886     std::optional<ArrayRef<bool>> scalableSizes) {
2887   // Loop types are automatically splat by the callee; setting up one is
2888   // enough.
2889   SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2890   build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2891         scalableSizes);
2892 }
2893
2894 void transform::TileUsingForOp::build(
2895     OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2896     Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2897     ArrayRef<int64_t> interchange,
2898     std::optional<ArrayRef<bool>> scalableSizes) {
2899   SmallVector<int64_t> staticTileSizes;
2900   SmallVector<Value> dynamicTileSizes;
2901   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2902   // Call the default builder which sets up the proper operands segment sizes
2903   // attributes for multiple variadic operands. In the absence of this,
2904   // horrible bugs ensue.
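  // One loop handle is produced per non-zero tile size; a tile size of 0
  // leaves the corresponding dimension untiled and thus generates no loop
  // result.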
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  unsigned numExpectedLoops =
      staticTileSizes.size() - llvm::count(staticTileSizes, 0);
  SmallVector<Type> resultTypes;
  resultTypes.reserve(numExpectedLoops);
  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
         "expected one loop type or one type per loop");
  if (loopTypes.size() == 1)
    resultTypes.append(numExpectedLoops, loopTypes[0]);
  else
    llvm::append_range(resultTypes, loopTypes);
  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
  if (scalableSizes.has_value())
    expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
  build(builder, result, /*tiled_linalg_op=*/target.getType(),
        /*loops=*/resultTypes,
        /*target=*/target,
        /*dynamic_sizes=*/dynamicTileSizes,
        /*static_sizes=*/staticTileSizesAttr,
        /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
        /*scalable_sizes=*/expandedScalableSizes);
}

LogicalResult transform::TileUsingForOp::verify() {
  if (getMixedSizes().size() != getScalableSizes().size())
    return emitOpError("expected same number of sizes (")
           << getMixedSizes().size() << ") and scalable sizes ("
           << getScalableSizes().size() << ")";
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
  if (getLoops().size() != numExpectedLoops)
    return emitOpError("expected number of loops to tile (")
           << numExpectedLoops << ") to match number of `loops` results ("
           << getLoops().size() << ")";
  return success();
}

DiagnosedSilenceableFailure
transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
                                 TransformResults &transformResults,
                                 TransformState &state) {
  ArrayRef<int64_t> tileSizes = getStaticSizes();

  SmallVector<Operation *> targets =
      llvm::to_vector(state.getPayloadOps(getTarget()));
  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
  SmallVector<SmallVector<int64_t>> paramSizes;
  dynamicSizeProducers.reserve(getDynamicSizes().size());
  paramSizes.reserve(getDynamicSizes().size());
  for (Value transformValue : getDynamicSizes()) {
    if (isa<ParamType>(transformValue.getType())) {
      dynamicSizeProducers.push_back({});
      ArrayRef<Attribute> params = state.getParams(transformValue);
      paramSizes.push_back(
          llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
            return cast<IntegerAttr>(attr).getValue().getSExtValue();
          })));

      if (paramSizes.back().size() != targets.size()) {
        DiagnosedSilenceableFailure diag =
            emitSilenceableError()
            << "expected as many parameter values ("
            << paramSizes.back().size() << ") as target ops ("
            << targets.size() << ")";
        diag.attachNote(transformValue.getLoc()) << "for this parameter";
        return diag;
      }

      continue;
    }
    paramSizes.push_back({});
    dynamicSizeProducers.push_back(
        llvm::to_vector(state.getPayloadOps(transformValue)));

    if (dynamicSizeProducers.back().size() != targets.size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected as many dynamic size-producing operations ("
          << dynamicSizeProducers.back().size() << ") as target ops ("
          << targets.size() << ")";
      diag.attachNote(transformValue.getLoc()) << "for this handle";
2986 return diag; 2987 } 2988 2989 for (Operation *op : dynamicSizeProducers.back()) { 2990 if (op->getNumResults() == 1 && 2991 isa<IndexType>(op->getResult(0).getType())) { 2992 continue; 2993 } 2994 2995 DiagnosedSilenceableFailure diag = 2996 emitSilenceableError() << "expected sizes to be produced by ops " 2997 "with a single index-type result"; 2998 diag.attachNote(op->getLoc()) << "size producer op"; 2999 diag.attachNote(transformValue.getLoc()) << "for this handle"; 3000 return diag; 3001 } 3002 } 3003 3004 SmallVector<Operation *> tiled; 3005 SmallVector<SmallVector<Operation *, 4>, 4> loops; 3006 loops.resize(getLoops().size()); 3007 auto scalableSizes = getScalableSizes(); 3008 for (auto [i, op] : llvm::enumerate(targets)) { 3009 auto tilingInterface = dyn_cast<TilingInterface>(op); 3010 if (!tilingInterface) { 3011 DiagnosedSilenceableFailure diag = 3012 emitSilenceableError() 3013 << "only ops implementing TilingInterface are supported"; 3014 diag.attachNote(op->getLoc()) << "target op"; 3015 return diag; 3016 } 3017 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) { 3018 DiagnosedSilenceableFailure diag = 3019 emitSilenceableError() 3020 << "too many tiles provided, expected at most " 3021 << tilingInterface.getLoopIteratorTypes().size() << " found " 3022 << tileSizes.size(); 3023 diag.attachNote(op->getLoc()) << "target op"; 3024 return diag; 3025 } 3026 3027 scf::SCFTilingOptions tilingOptions; 3028 if (tileSizes.empty()) { 3029 tilingOptions.setTileSizeComputationFunction( 3030 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> { 3031 return {}; 3032 }); 3033 } else { 3034 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b, 3035 Operation *) { 3036 SmallVector<OpFoldResult> sizes; 3037 sizes.reserve(tileSizes.size()); 3038 unsigned dynamicIdx = 0; 3039 3040 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { 3041 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) { 3042 if (scalableSizes[ofrIdx]) { 3043 auto val = b.create<arith::ConstantIndexOp>( 3044 getLoc(), cast<IntegerAttr>(attr).getInt()); 3045 Value vscale = 3046 b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType()); 3047 sizes.push_back( 3048 b.create<arith::MulIOp>(getLoc(), val, vscale).getResult()); 3049 } else { 3050 sizes.push_back(attr); 3051 } 3052 continue; 3053 } 3054 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx]; 3055 ArrayRef<int64_t> params = paramSizes[dynamicIdx]; 3056 ++dynamicIdx; 3057 assert((dynamicSizes.empty() ^ params.empty()) && 3058 "expected either dynamic sizes or parameters"); 3059 if (!params.empty()) { 3060 sizes.push_back(b.getIndexAttr(params[index])); 3061 } else { 3062 sizes.push_back(dynamicSizes[index]->getResult(0)); 3063 } 3064 } 3065 return sizes; 3066 }); 3067 } 3068 3069 tilingOptions.setInterchange(getInterchange()); 3070 FailureOr<scf::SCFTilingResult> maybeTilingResult = 3071 tileUsingSCF(rewriter, tilingInterface, tilingOptions); 3072 if (failed(maybeTilingResult)) 3073 return DiagnosedSilenceableFailure::definiteFailure(); 3074 3075 rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements); 3076 3077 tiled.append(maybeTilingResult->tiledOps); 3078 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) 3079 loops[en2.index()].push_back(en2.value()); 3080 } 3081 3082 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled); 3083 for (const auto &en : llvm::enumerate(loops)) 3084 transformResults.set(cast<OpResult>(getLoops()[en.index()]), 
en.value()); 3085 3086 return DiagnosedSilenceableFailure::success(); 3087 } 3088 3089 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() { 3090 ValueRange dynamic = getDynamicSizes(); 3091 ArrayRef<int64_t> tileSizes = getStaticSizes(); 3092 SmallVector<OpFoldResult> results; 3093 results.reserve(tileSizes.size()); 3094 unsigned dynamicPos = 0; 3095 Builder builder(getContext()); 3096 for (int64_t size : tileSizes) { 3097 if (size == ShapedType::kDynamic) { 3098 results.push_back(dynamic[dynamicPos++]); 3099 } else { 3100 results.push_back(builder.getIndexAttr(size)); 3101 } 3102 } 3103 return results; 3104 } 3105 3106 void transform::TileUsingForOp::getEffects( 3107 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3108 consumesHandle(getTargetMutable(), effects); 3109 onlyReadsHandle(getDynamicSizesMutable(), effects); 3110 producesHandle(getOperation()->getOpResults(), effects); 3111 modifiesPayload(effects); 3112 } 3113 3114 //===----------------------------------------------------------------------===// 3115 // TileUsingForallOp 3116 //===----------------------------------------------------------------------===// 3117 3118 void transform::TileUsingForallOp::build(OpBuilder &builder, 3119 OperationState &result, Value target, 3120 ArrayRef<int64_t> staticTileSizes, 3121 transform::TileSizesSpec, 3122 ArrayAttr mapping) { 3123 return build(builder, result, 3124 /*target=*/target, 3125 /*mixedTileSizes=*/ 3126 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), 3127 /*_=*/TileSizesSpec(), 3128 /*mapping=*/mapping); 3129 } 3130 3131 void transform::TileUsingForallOp::build(OpBuilder &builder, 3132 OperationState &result, Value target, 3133 ArrayRef<OpFoldResult> mixedTileSizes, 3134 transform::TileSizesSpec, 3135 ArrayAttr mapping) { 3136 SmallVector<int64_t> staticTileSizes; 3137 SmallVector<Value> dynamicTileSizes; 3138 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); 3139 // Call the default builder which sets up the proper operands segment sizes 3140 // attributes for multiple variadic operands. In the absence of this, 3141 // horrible bugs ensue. 
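  // (These ops have several variadic operand groups; the default builder also
  // populates the operand segment sizes attribute that keeps `num_threads`,
  // `tile_sizes`, and the packed variants apart.)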
  MLIRContext *ctx = builder.getContext();
  auto operationType = transform::AnyOpType::get(ctx);
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  build(builder, result,
        /*resultTypes=*/TypeRange{operationType, operationType},
        /*target=*/target,
        /*num_threads=*/ValueRange{},
        /*tile_sizes=*/dynamicTileSizes,
        /*packed_num_threads=*/Value(),
        /*packed_tile_sizes=*/Value(),
        /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
        /*static_tile_sizes=*/staticTileSizesAttr,
        /*mapping=*/mapping);
}

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<int64_t> staticNumThreads,
                                         transform::NumThreadsSpec,
                                         ArrayAttr mapping) {
  return build(builder, result, target,
               getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
               NumThreadsSpec(), mapping);
}

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<OpFoldResult> mixedNumThreads,
                                         transform::NumThreadsSpec,
                                         ArrayAttr mapping) {
  SmallVector<int64_t> staticNumThreads;
  SmallVector<Value> dynamicNumThreads;
  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
                             staticNumThreads);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  MLIRContext *ctx = builder.getContext();
  auto operationType = transform::AnyOpType::get(ctx);
  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
  build(builder, result,
        /*resultTypes=*/TypeRange{operationType, operationType},
        /*target=*/target,
        /*num_threads=*/dynamicNumThreads,
        /*tile_sizes=*/ValueRange{},
        /*packed_num_threads=*/Value(),
        /*packed_tile_sizes=*/Value(),
        /*static_num_threads=*/staticNumThreadsAttr,
        /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
        /*mapping=*/mapping);
}

/// Given the `lbs`, `ubs` and `steps` of loops, return, for each loop, the
/// normalized upper bound, i.e. `ceilDiv(ub - lb, step)`.
static SmallVector<OpFoldResult>
normalizeUpperBounds(RewriterBase &rewriter, Location loc,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> steps) {
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
  SmallVector<OpFoldResult> normalizedUbs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
    OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
        rewriter, loc, normalizedUbExpr, {lb, ub, step});
    normalizedUbs.push_back(normalizedUb);
  }
  return normalizedUbs;
}

/// When a loop is normalized, the uses of the induction variable within the
/// loop need to be replaced with `original_lb + old_iv * original_step`.
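/// For example, with `original_lb = 2` and `original_step = 3`, the value 1 of
/// the normalized induction variable maps back to `2 + 1 * 3 = 5` in the
/// original iteration space.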
static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
                                            Location loc, ValueRange ivs,
                                            ArrayRef<OpFoldResult> lbs,
                                            ArrayRef<OpFoldResult> steps) {
  AffineExpr s0, s1;
  AffineExpr d0;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr denormExpr = s0 + d0 * s1;
  SmallVector<Value> denormalizedIvs;

  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
    OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
        rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
    denormalizedIvs.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
  }
  return denormalizedIvs;
}

/// Given an `scf.forall` loop, return a loop op with the loop bounds
/// normalized.
/// TODO: Replace this with a general utility to normalize `scf.forall`.
/// At the time of writing, this wasn't done since adding it to the `scf`
/// dialect would disallow using `affine.apply` operations due to cyclic
/// dependencies. To avoid churn in the lit tests accompanying the change this
/// was added with, defer that to a follow-up.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
                                           scf::ForallOp loop) {
  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = loop.getMixedStep();

  if (llvm::all_of(
          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
      llvm::all_of(
          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
    return loop;
  }

  Location loc = loop.getLoc();
  SmallVector<OpFoldResult> normalizedUbs =
      normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
                                          rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
                                            rewriter.getIndexAttr(1));

  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
      loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
      loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});

  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
  OpBuilder::InsertionGuard g(rewriter);
  Block *normalizedLoopBlock = normalizedForallOp.getBody();
  rewriter.setInsertionPointToStart(normalizedLoopBlock);

  SmallVector<Value> argValues =
      denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
                   normalizedForallOp.getRegionIterArgs().end());
  Block *origLoopBlock = loop.getBody();
  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);

  rewriter.replaceOp(loop, normalizedForallOp);
  return normalizedForallOp;
}

DiagnosedSilenceableFailure transform::tileToForallOpImpl(
    RewriterBase &rewriter, transform::TransformState &state,
    TransformOpInterface transformOp, Operation *target,
    ArrayRef<OpFoldResult> mixedNumThreads,
    ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
    scf::SCFTilingResult &tilingResult) {
  // Transform one target at a time; callers iterate over all payload targets.
3289 auto tileableOp = dyn_cast<TilingInterface>(target); 3290 if (!tileableOp) { 3291 DiagnosedSilenceableFailure diag = 3292 transformOp.emitSilenceableError() 3293 << "only TilingInterface ops are supported"; 3294 diag.attachNote(target->getLoc()) << "target op"; 3295 return diag; 3296 } 3297 rewriter.setInsertionPoint(tileableOp); 3298 scf::SCFTilingOptions options; 3299 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 3300 if (!mixedNumThreads.empty()) { 3301 options.setNumThreads(mixedNumThreads); 3302 } else { 3303 options.setTileSizes(mixedTileSizes); 3304 } 3305 if (mapping) { 3306 options.setMapping(mapping.value().getValue()); 3307 } 3308 FailureOr<scf::SCFTilingResult> maybeTilingResult = 3309 scf::tileUsingSCF(rewriter, tileableOp, options); 3310 3311 if (failed(maybeTilingResult)) 3312 return transformOp.emitDefaultSilenceableFailure(tileableOp); 3313 3314 rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements); 3315 3316 tilingResult = *maybeTilingResult; 3317 3318 if (mixedNumThreads.empty()) { 3319 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front()); 3320 OpBuilder::InsertionGuard g(rewriter); 3321 rewriter.setInsertionPoint(generatedForallOp); 3322 scf::ForallOp normalizedForallOp = 3323 normalizeForallLoopOp(rewriter, generatedForallOp); 3324 tilingResult.loops.front() = normalizedForallOp; 3325 } 3326 3327 return DiagnosedSilenceableFailure::success(); 3328 } 3329 3330 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply( 3331 transform::TransformRewriter &rewriter, 3332 transform::TransformResults &transformResults, 3333 transform::TransformState &state) { 3334 auto transformOp = cast<TransformOpInterface>(getOperation()); 3335 3336 // Result payload ops. 3337 SmallVector<Operation *> tileOps; 3338 SmallVector<Operation *> tiledOps; 3339 3340 // Unpack handles. 3341 SmallVector<OpFoldResult> mixedNumThreads; 3342 DiagnosedSilenceableFailure status = 3343 getPackedNumThreads() 3344 ? unpackSingleIndexResultPayloadOperations( 3345 state, transformOp, mixedNumThreads, getPackedNumThreads()) 3346 : unpackSingleIndexResultPayloadOperations( 3347 state, transformOp, mixedNumThreads, getMixedNumThreads()); 3348 if (!status.succeeded()) 3349 return status; 3350 SmallVector<OpFoldResult> mixedTileSizes; 3351 status = getPackedTileSizes() 3352 ? 
unpackSingleIndexResultPayloadOperations( 3353 state, transformOp, mixedTileSizes, getPackedTileSizes()) 3354 : unpackSingleIndexResultPayloadOperations( 3355 state, transformOp, mixedTileSizes, getMixedTileSizes()); 3356 if (!status.succeeded()) 3357 return status; 3358 3359 for (Operation *target : state.getPayloadOps(getTarget())) { 3360 scf::SCFTilingResult tilingResult; 3361 DiagnosedSilenceableFailure diag = tileToForallOpImpl( 3362 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, 3363 getMapping(), tilingResult); 3364 if (!diag.succeeded()) 3365 return diag; 3366 tileOps.push_back(tilingResult.loops.front()); 3367 tiledOps.append(tilingResult.tiledOps); 3368 } 3369 3370 transformResults.set(cast<OpResult>(getForallOp()), tileOps); 3371 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps); 3372 3373 return DiagnosedSilenceableFailure::success(); 3374 } 3375 3376 void transform::TileUsingForallOp::getEffects( 3377 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3378 consumesHandle(getTargetMutable(), effects); 3379 onlyReadsHandle(getTileSizesMutable(), effects); 3380 onlyReadsHandle(getNumThreadsMutable(), effects); 3381 onlyReadsHandle(getPackedNumThreadsMutable(), effects); 3382 onlyReadsHandle(getPackedTileSizesMutable(), effects); 3383 producesHandle(getOperation()->getOpResults(), effects); 3384 modifiesPayload(effects); 3385 } 3386 3387 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() { 3388 Builder b(getContext()); 3389 return getMixedValues(getStaticNumThreads(), getNumThreads(), b); 3390 } 3391 3392 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() { 3393 Builder b(getContext()); 3394 return getMixedValues(getStaticTileSizes(), getTileSizes(), b); 3395 } 3396 3397 LogicalResult TileUsingForallOp::verify() { 3398 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) + 3399 static_cast<int>(getPackedNumThreads() != Value()); 3400 if (numThreadsSpec > 1) 3401 return emitOpError( 3402 "num_threads and packed_num_threads are mutually exclusive"); 3403 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) + 3404 static_cast<int>(getPackedTileSizes() != Value()); 3405 if (tileSizesSpec > 1) 3406 return emitOpError( 3407 "tile_sizes and packed_tile_sizes are mutually exclusive"); 3408 if (numThreadsSpec == 0 && tileSizesSpec == 0) 3409 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes " 3410 "must be specified"); 3411 return success(); 3412 } 3413 3414 //===----------------------------------------------------------------------===// 3415 // VectorizeChildrenAndApplyPatternsOp 3416 //===----------------------------------------------------------------------===// 3417 3418 void transform::VectorizeChildrenAndApplyPatternsOp::build( 3419 OpBuilder &builder, OperationState &result, Value target, 3420 bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) { 3421 result.addOperands(target); 3422 if (vectorizePadding) { 3423 result.addAttribute( 3424 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName( 3425 result.name), 3426 builder.getUnitAttr()); 3427 } 3428 if (vectorizeExtract) { 3429 result.addAttribute( 3430 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName( 3431 result.name), 3432 builder.getUnitAttr()); 3433 } 3434 if (flatten1DDepthwiseConv) { 3435 result.addAttribute( 3436 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName( 3437 result.name), 3438 builder.getUnitAttr()); 3439 } 3440 
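  // The op has a single `!transform.any_op` result: a handle to the target op
  // after the vectorization patterns have been applied in applyToOne.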
  result.addTypes(transform::AnyOpType::get(builder.getContext()));
}

namespace {
/// This is a helper only to call vectorize via a pattern inside of
/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
  explicit VectorizationPattern(MLIRContext *context,
                                bool vectorizeExtract = false,
                                bool flattenConv = false)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        vectorizeNDExtract(vectorizeExtract),
        flatten1DDepthwiseConv(flattenConv) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!linalg::hasVectorizationImpl(op))
      return rewriter.notifyMatchFailure(op,
                                         "Unsupported Op, cannot vectorize");
    return vectorize(rewriter, op, /*inputVectorSizes=*/{},
                     /*inputScalableVecDims=*/{}, vectorizeNDExtract,
                     flatten1DDepthwiseConv);
  }

private:
  /// Controls whether to vectorize `tensor.extract` when the input tensor is
  /// rank >= 2.
  bool vectorizeNDExtract = false;
  /// Controls whether to "flatten" the channel dimension when vectorising 1D
  /// depthwise convolutions. This should lead to better vectorization for
  /// tensors with a low number of channel dimensions.
  bool flatten1DDepthwiseConv = false;
};
} // namespace

DiagnosedSilenceableFailure
transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  MLIRContext *ctx = getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
                                     getFlatten_1dDepthwiseConv());

  if (!getDisableTransferPermutationMapLoweringPatterns())
    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

  if (!getDisableMultiReductionToContractPatterns())
    vector::populateVectorReductionToContractPatterns(patterns);

  vector::populateSinkVectorOpsPatterns(patterns);

  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
               linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                       /*benefit=*/2);
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);

  patterns.add<CopyVectorizationPattern>(ctx);

  // Add misc. vectorization patterns (e.g. for tensor.insert_slice).
  linalg::populateInsertSliceVectorizationPatterns(patterns);

  if (getVectorizePadding()) {
    linalg::populatePadOpVectorizationPatterns(patterns);
    // This creates an alternative path for lowering tensor.pad, by
    // decomposing it into e.g. linalg.fill.
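    // Roughly, a statically shaped `tensor.pad` can be rewritten as a
    // `tensor.empty` of the padded shape, a `linalg.fill` with the padding
    // value, and a `tensor.insert_slice` of the source into the filled
    // tensor, all of which are vectorizable.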
3514 linalg::populateDecomposePadPatterns(patterns); 3515 } 3516 vector::populateVectorStepLoweringPatterns(patterns); 3517 3518 TrackingListener listener(state, *this); 3519 GreedyRewriteConfig config; 3520 config.listener = &listener; 3521 if (failed(applyPatternsGreedily(target, std::move(patterns), config))) 3522 return emitDefaultDefiniteFailure(target); 3523 3524 results.push_back(target); 3525 return DiagnosedSilenceableFailure::success(); 3526 } 3527 3528 //===----------------------------------------------------------------------===// 3529 // VectorizeOp 3530 //===----------------------------------------------------------------------===// 3531 3532 DiagnosedSilenceableFailure transform::VectorizeOp::apply( 3533 transform::TransformRewriter &rewriter, 3534 mlir::transform::TransformResults &transformResults, 3535 mlir::transform::TransformState &state) { 3536 auto targets = state.getPayloadOps(getTarget()); 3537 if (std::empty(targets)) 3538 return DiagnosedSilenceableFailure::success(); 3539 auto transformOp = cast<TransformOpInterface>(getOperation()); 3540 SmallVector<int64_t> vectorSizes; 3541 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( 3542 state, transformOp, getMixedVectorSizes(), vectorSizes); 3543 if (!status.succeeded()) 3544 return status; 3545 3546 // TODO: Check that the correct number of vectorSizes was provided. 3547 for (Operation *target : targets) { 3548 if (!linalg::hasVectorizationImpl(target)) { 3549 return mlir::emitSilenceableFailure(target->getLoc()) 3550 << "Unsupported Op, cannot vectorize"; 3551 } 3552 3553 if (failed(linalg::vectorize(rewriter, target, vectorSizes, 3554 getScalableSizes(), 3555 getVectorizeNdExtract().value_or(false)))) { 3556 return mlir::emitSilenceableFailure(target->getLoc()) 3557 << "Attempted to vectorize, but failed"; 3558 } 3559 } 3560 3561 return DiagnosedSilenceableFailure::success(); 3562 } 3563 3564 void transform::VectorizeOp::getEffects( 3565 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3566 consumesHandle(getTargetMutable(), effects); 3567 onlyReadsHandle(getVectorSizesMutable(), effects); 3568 modifiesPayload(effects); 3569 } 3570 3571 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() { 3572 OpBuilder b(getContext()); 3573 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b); 3574 } 3575 3576 LogicalResult transform::VectorizeOp::verify() { 3577 if (getStaticVectorSizes().size() != getScalableSizes().size()) 3578 return emitOpError("expected same number of vector sizes (") 3579 << getStaticVectorSizes().size() << ") and scalable sizes (" 3580 << getScalableSizes().size() << ")"; 3581 return success(); 3582 } 3583 3584 //===----------------------------------------------------------------------===// 3585 // HoistRedundantVectorTransfersOp 3586 //===----------------------------------------------------------------------===// 3587 3588 DiagnosedSilenceableFailure 3589 transform::HoistRedundantVectorTransfersOp::applyToOne( 3590 transform::TransformRewriter &rewriter, func::FuncOp target, 3591 transform::ApplyToEachResultList &results, 3592 transform::TransformState &state) { 3593 // WARNING: This hoisting does not model parallelism and is generally 3594 // incorrect when used on distributed loops with memref semantics! 3595 // TODO: obsolete and should be retired. 
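  // The utility hoists loop-invariant vector.transfer_read /
  // vector.transfer_write pairs out of their enclosing loops; when
  // getVerifyNonZeroTrip() is set, hoisting is restricted to loops that can
  // be shown to execute at least once.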
3596 linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip()); 3597 results.push_back(target); 3598 return DiagnosedSilenceableFailure::success(); 3599 } 3600 3601 //===----------------------------------------------------------------------===// 3602 // HoistRedundantVectorBroadcastsOp 3603 //===----------------------------------------------------------------------===// 3604 3605 DiagnosedSilenceableFailure 3606 transform::HoistRedundantVectorBroadcastsOp::applyToOne( 3607 transform::TransformRewriter &rewriter, mlir::Operation *target, 3608 transform::ApplyToEachResultList &results, 3609 transform::TransformState &state) { 3610 rewriter.setInsertionPoint(target); 3611 linalg::hoistRedundantVectorBroadcasts(rewriter, target); 3612 results.push_back(target); 3613 return DiagnosedSilenceableFailure::success(); 3614 } 3615 3616 //===----------------------------------------------------------------------===// 3617 // ConvertConv2DToImg2ColOp. 3618 //===----------------------------------------------------------------------===// 3619 3620 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( 3621 transform::TransformRewriter &rewriter, linalg::LinalgOp target, 3622 transform::ApplyToEachResultList &results, 3623 transform::TransformState &state) { 3624 rewriter.setInsertionPoint(target); 3625 auto maybeTransformed = 3626 TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>( 3627 target) 3628 .Case([&](linalg::Conv2DNhwcHwcfOp op) { 3629 return rewriteInIm2Col(rewriter, op); 3630 }) 3631 .Case([&](linalg::Conv2DNhwcFhwcOp op) { 3632 return rewriteInIm2Col(rewriter, op); 3633 }) 3634 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) { 3635 return rewriteInIm2Col(rewriter, op); 3636 }) 3637 .Case([&](linalg::Conv2DNchwFchwOp op) { 3638 return rewriteInIm2Col(rewriter, op); 3639 }) 3640 .Default([&](Operation *op) { 3641 return rewriter.notifyMatchFailure(op, "not supported"); 3642 }); 3643 if (failed(maybeTransformed)) 3644 return emitDefaultSilenceableFailure(target); 3645 // Handle to the operation producing the img2col tensor. 3646 results.push_back(maybeTransformed->first); 3647 // Handle to the operation that replaces the original convolution. 3648 results.push_back(maybeTransformed->second); 3649 return DiagnosedSilenceableFailure::success(); 3650 } 3651 3652 //===----------------------------------------------------------------------===// 3653 // FlattenElementwiseLinalgOp. 3654 //===----------------------------------------------------------------------===// 3655 3656 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( 3657 transform::TransformRewriter &rewriter, linalg::LinalgOp target, 3658 transform::ApplyToEachResultList &results, 3659 transform::TransformState &state) { 3660 rewriter.setInsertionPoint(target); 3661 if (!isElementwise(target)) 3662 return mlir::emitSilenceableFailure(target->getLoc()) 3663 << "only elementwise flattening is supported"; 3664 3665 // If rank <= 1, do nothing 3666 if (target.getNumLoops() <= 1) { 3667 results.push_back(target); 3668 return DiagnosedSilenceableFailure::success(); 3669 } 3670 3671 // Attempt to flatten all dims to one. 
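  // For example, a rank-3 elementwise op yields the single reassociation group
  // [0, 1, 2], collapsing all three iteration dimensions into one.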
  ReassociationIndices reassociation(target.getNumLoops());
  std::iota(reassociation.begin(), reassociation.end(), 0);
  auto maybeFlattened =
      collapseOpIterationDims(target, reassociation, rewriter);
  if (failed(maybeFlattened))
    return mlir::emitSilenceableFailure(target->getLoc())
           << "attempted to flatten, but failed";
  results.push_back(maybeFlattened->collapsedOp);
  rewriter.replaceOp(target, maybeFlattened->results);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  auto maybeTransformed =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
            return transposeConv2D(rewriter, op);
          })
          .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
            return transposeConv2D(rewriter, op);
          })
          .Default([&](Operation *op) {
            return rewriter.notifyMatchFailure(op, "not supported");
          });
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
  // Handle to the new Conv2D operation with transposed filters.
  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
  auto maybeTransformed =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case([&](linalg::MatmulOp op) {
            return transposeMatmul(rewriter, op, transposeLHS);
          })
          .Case([&](linalg::BatchMatmulOp op) {
            return transposeBatchMatmul(rewriter, op, transposeLHS);
          })
          .Default([&](Operation *op) { return failure(); });
  if (failed(maybeTransformed))
    return emitSilenceableFailure(target->getLoc()) << "not supported";
  // Handle to the new matmul operation with the selected operand transposed.
  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
template <typename OpTy>
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
                                 transform::ApplyToEachResultList &results,
                                 transform::TransformState &state) {
  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
                                tensor::ParallelInsertSliceOp>() &&
                    "wrong op type");

  if (auto copySource =
          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
    results.push_back(copySource);
    return
DiagnosedSilenceableFailure::success(); 3752 } 3753 3754 // If we are inside an InParallel region, temporarily set the insertion point 3755 // outside: only tensor.parallel_insert_slice ops are allowed in there. 3756 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) { 3757 rewriter.setInsertionPoint( 3758 target->template getParentOfType<scf::InParallelOp>()); 3759 } 3760 3761 Value extracted = rewriter.create<tensor::ExtractSliceOp>( 3762 target.getLoc(), target.getDest(), target.getMixedOffsets(), 3763 target.getMixedSizes(), target.getMixedStrides()); 3764 Value copied = rewriter 3765 .create<linalg::CopyOp>(target.getLoc(), 3766 target.getSource(), extracted) 3767 .getResult(0); 3768 // Reset the insertion point. 3769 rewriter.setInsertionPoint(target); 3770 rewriter.replaceOpWithNewOp<OpTy>( 3771 target, copied, target.getDest(), target.getMixedOffsets(), 3772 target.getMixedSizes(), target.getMixedStrides()); 3773 3774 results.push_back(copied.getDefiningOp()); 3775 return DiagnosedSilenceableFailure::success(); 3776 } 3777 3778 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne( 3779 transform::TransformRewriter &rewriter, Operation *targetOp, 3780 transform::ApplyToEachResultList &results, 3781 transform::TransformState &state) { 3782 3783 rewriter.setInsertionPoint(targetOp); 3784 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp)) 3785 return doit(rewriter, target, results, state); 3786 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp)) 3787 return doit(rewriter, target, results, state); 3788 3789 DiagnosedSilenceableFailure diag = 3790 emitSilenceableError() 3791 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported"; 3792 diag.attachNote(targetOp->getLoc()) << "target op"; 3793 return diag; 3794 } 3795 3796 //===----------------------------------------------------------------------===// 3797 // MapCopyToThreadsOp 3798 //===----------------------------------------------------------------------===// 3799 3800 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( 3801 transform::TransformRewriter &rewriter, Operation *target, 3802 transform::ApplyToEachResultList &results, 3803 transform::TransformState &state) { 3804 // Check if the op is supported. 3805 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) { 3806 DiagnosedSilenceableFailure diag = 3807 emitSilenceableError() 3808 << "only linalg.copy and tensor.pad target ops are supported"; 3809 diag.attachNote(target->getLoc()) << "target op"; 3810 return diag; 3811 } 3812 assert(target->getNumResults() == 1 && "expected single result"); 3813 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType()); 3814 if (!resultShapedType.hasStaticShape()) { 3815 DiagnosedSilenceableFailure diag = 3816 emitSilenceableError() 3817 << "only statically sized ops of rank <= 3 are supported"; 3818 diag.attachNote(target->getLoc()) << "target op"; 3819 return diag; 3820 } 3821 3822 // Conservatively set the minimum viable desired bitwidth alignment. 
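  // If the requested alignment is not a multiple of the element bitwidth
  // (e.g. a 128-bit alignment with an i48 element type), fall back to the
  // element bitwidth itself.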
  int64_t desiredBitAlignment = getDesiredBitAlignment();
  int64_t eltBitwidth =
      resultShapedType.getElementType().getIntOrFloatBitWidth();
  if (desiredBitAlignment % eltBitwidth != 0) {
    desiredBitAlignment = eltBitwidth;
  }

  gpu::CopyMappingInfo mapping(
      /*ctx=*/getContext(),
      /*totalNumThreads=*/getTotalNumThreads(),
      /*alignment=*/desiredBitAlignment,
      /*sizes=*/resultShapedType.getShape(),
      /*favorPredication=*/false,
      /*elementalBitwidth=*/
      resultShapedType.getElementType().getIntOrFloatBitWidth());
  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "too few threads to map copy op to threads on the most minor "
           "dimension, given alignment and vector size constraints; try a "
           "smaller tile size or mapping to more threads";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // OpBuilder only used to compute attributes.
  OpBuilder b(getContext());
  scf::SCFTilingResult tilingResult;
  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
      /*rewriter=*/rewriter,
      /*state=*/state,
      /*transformOp=*/*this,
      /*target=*/target,
      /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
      /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
      /*mapping=*/b.getArrayAttr(mapping.threadMapping),
      /*tilingResult=*/tilingResult);
  if (!diag.succeeded())
    return diag;

  results.push_back(tilingResult.loops.front());
  for (auto op : tilingResult.tiledOps)
    results.push_back(op);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// WinogradConv2DOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<Operation *> maybeTransformed = failure();
  bool supported = TypeSwitch<Operation *, bool>(target)
                       .Case([&](linalg::Conv2DNhwcFhwcOp op) {
                         maybeTransformed =
                             winogradConv2D(rewriter, op, getM(), getR());
                         return true;
                       })
                       .Default([&](Operation *op) { return false; });

  if (!supported) {
    return emitSilenceableError()
           << "this operation is not supported for conversion to Winograd "
              "Conv2D";
  }

  if (failed(maybeTransformed)) {
    return emitSilenceableError() << "failed to apply Winograd Conv2D";
  }

  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<Operation *> maybeTransformed = failure();
  bool supported =
      TypeSwitch<Operation *, bool>(target)
          .Case([&](linalg::WinogradFilterTransformOp op) {
            maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
            return true;
          })
          .Case([&](linalg::WinogradInputTransformOp op) {
            maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
            return true;
          })
          .Case([&](linalg::WinogradOutputTransformOp op) {
            maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
            return true;
          })
          .Default([&](Operation *op) { return false; });

  if (!supported) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "decomposition of this operation is not supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  if (failed(maybeTransformed)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "failed to decompose the Winograd operation";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"