1 //===- LinalgTransformOps.cpp - Implementation of Linalg match ops --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" 10 #include "mlir/Analysis/SliceAnalysis.h" 11 #include "mlir/Dialect/Linalg/IR/Linalg.h" 12 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" 13 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 16 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/Interfaces/FunctionImplementation.h" 19 #include "llvm/Support/Debug.h" 20 #include "llvm/Support/FormatVariadic.h" 21 22 using namespace mlir; 23 24 #define DEBUG_TYPE "linalg-transforms" 25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 26 27 //===----------------------------------------------------------------------===// 28 // StructuredMatchOp 29 //===----------------------------------------------------------------------===// 30 31 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( 32 Operation *current, transform::TransformResults &results, 33 transform::TransformState &state) { 34 // First, check if the payload operation is a structured Linalg operation. 35 if (!isa<linalg::LinalgOp>(current)) { 36 if (getFailurePropagationMode().value_or( 37 FailurePropagationMode::Propagate) == 38 FailurePropagationMode::Propagate) { 39 return emitSilenceableError() << "expected a Linalg op"; 40 } 41 // If errors are suppressed, succeed and set all results to empty lists. 42 LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op"); 43 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); 44 return DiagnosedSilenceableFailure::success(); 45 } 46 47 // Bind `current` to the block argument. 48 auto scope = state.make_region_scope(getBodyRegion()); 49 if (failed(state.mapBlockArgument(getBody()->getArgument(0), 50 MappedValue(current)))) { 51 return DiagnosedSilenceableFailure::definiteFailure(); 52 } 53 54 for (Operation &nested : getBody()->without_terminator()) { 55 DiagnosedSilenceableFailure diag = 56 state.applyTransform(cast<TransformOpInterface>(nested)); 57 if (diag.isDefiniteFailure()) 58 return diag; 59 if (diag.succeeded()) 60 continue; 61 62 // If propagating errors, do this immediately. 63 assert(diag.isSilenceableFailure()); 64 if (getFailurePropagationMode().value_or( 65 FailurePropagationMode::Propagate) == 66 FailurePropagationMode::Propagate) { 67 return diag; 68 } 69 70 // If suppressing errors, print the message into the debug stream before 71 // silencing it. Then set all results value that are already known. 72 // Results come from the terminator operands, which may be defined in the 73 // (single) block of this operation or above it. When they are defined 74 // above, they are known to be mapped at this point per SSA dominance. 75 // When they are defined in this block, we additionally check if we have 76 // already applied the operation that defines them. If not, the 77 // corresponding results will be set to empty lists. 78 LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() 79 << "\n"); 80 (void)diag.silence(); 81 SmallVector<OpOperand *> undefinedOperands; 82 for (OpOperand &terminatorOperand : 83 getBody()->getTerminator()->getOpOperands()) { 84 Operation *definingOp = terminatorOperand.get().getDefiningOp(); 85 if (!definingOp) 86 continue; 87 if (definingOp->getBlock() != getBody()) 88 continue; 89 if (definingOp->isBeforeInBlock(&nested)) 90 continue; 91 92 undefinedOperands.push_back(&terminatorOperand); 93 } 94 95 SmallVector<SmallVector<transform::MappedValue>> mappings; 96 auto filtered = llvm::make_filter_range( 97 getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) { 98 return !llvm::is_contained(undefinedOperands, &opOperand); 99 }); 100 SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range( 101 filtered, [](OpOperand &opOperand) { return opOperand.get(); })); 102 detail::prepareValueMappings(mappings, definedOperands, state); 103 for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) { 104 results.setMappedValues(getResults()[operand.getOperandNumber()], 105 mapping); 106 } 107 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); 108 return DiagnosedSilenceableFailure::success(); 109 } 110 111 // Set the results. 112 detail::forwardTerminatorOperands(getBody(), state, results); 113 return DiagnosedSilenceableFailure::success(); 114 } 115 116 void transform::MatchStructuredOp::getEffects( 117 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 118 onlyReadsHandle(getCurrentMutable(), effects); 119 onlyReadsPayload(effects); 120 producesHandle(getOperation()->getOpResults(), effects); 121 } 122 123 LogicalResult transform::MatchStructuredOp::verify() { 124 if (getBody()->getNumArguments() != 1) 125 return emitOpError() << "expected one body argument"; 126 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) { 127 return emitOpError() << "expected body argument to implement " 128 "TransformHandleTypeInterface"; 129 } 130 for (Operation &nested : getBody()->without_terminator()) { 131 if (isa<MatchOpInterface>(nested)) 132 continue; 133 InFlightDiagnostic diag = 134 emitOpError() 135 << "expects nested operations to implement MatchOpInterface"; 136 diag.attachNote(nested.getLoc()) << "offending operation"; 137 return diag; 138 } 139 return success(); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // StructuredOpPredicateOpTrait 144 //===----------------------------------------------------------------------===// 145 146 LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait( 147 Operation *op, Value structuredOpHandle) { 148 if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) { 149 return op->emitOpError() << "expects parent op to be '" 150 << MatchStructuredOp::getOperationName() << "'"; 151 } 152 153 // Bail out here, let the verifier of the parent complain. 154 Operation *parent = op->getParentOp(); 155 if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() || 156 parent->getRegion(0).front().getNumArguments() < 1) 157 return success(); 158 159 if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) { 160 return op->emitOpError() 161 << "expected predicate to apply to the surrounding structured op"; 162 } 163 return success(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // MatchStructuredBodyOp 168 //===----------------------------------------------------------------------===// 169 170 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( 171 Operation *current, transform::TransformResults &results, 172 transform::TransformState &state) { 173 auto linalgOp = cast<linalg::LinalgOp>(current); 174 if (std::optional<uint64_t> position = getReductionPosition()) { 175 SmallVector<Operation *> combinerOps; 176 if (!matchReduction(linalgOp.getRegionOutputArgs(), *position, 177 combinerOps)) { 178 return emitSilenceableError() << "could not match reduction"; 179 } 180 if (combinerOps.size() != 1) { 181 return emitSilenceableError() << "reduction combiner is not a single op"; 182 } 183 return DiagnosedSilenceableFailure::success(); 184 } 185 if (getPassthrough()) { 186 Block &body = linalgOp->getRegion(0).front(); 187 if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) { 188 return emitSilenceableError() << "not a passthrough"; 189 } 190 return DiagnosedSilenceableFailure::success(); 191 } 192 if (getElementwise()) { 193 if (!isElementwise(linalgOp)) 194 return emitSilenceableError() << "not elementwise"; 195 return DiagnosedSilenceableFailure::success(); 196 } 197 if (std::optional<ArrayAttr> contractionOps = getContraction()) { 198 Block &body = linalgOp->getRegion(0).front(); 199 std::string message; 200 llvm::raw_string_ostream os(message); 201 bool result = linalg::detail::isContractionBody( 202 body, 203 [&](Operation *elem, Operation *red) { 204 return elem->getName().getStringRef() == 205 cast<StringAttr>((*contractionOps)[0]).getValue() && 206 red->getName().getStringRef() == 207 cast<StringAttr>((*contractionOps)[1]).getValue(); 208 }, 209 os); 210 if (result) 211 return DiagnosedSilenceableFailure::success(); 212 return emitSilenceableError() << "contraction: " << message; 213 } 214 return emitDefiniteFailure() << "unknown body condition"; 215 } 216 217 LogicalResult transform::MatchStructuredBodyOp::verify() { 218 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + 219 getElementwise() + getContraction().has_value(); 220 221 if (numOptions > 1) { 222 std::string attributeNames; 223 llvm::raw_string_ostream os(attributeNames); 224 llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(), 225 getPassthroughAttrName(), 226 getElementwiseAttrName(), 227 getContractionAttrName()}, 228 os); 229 return emitOpError() << "only one of {" << attributeNames << "} is allowed"; 230 } 231 232 if (std::optional<ArrayAttr> contractionAttr = getContraction()) { 233 if (contractionAttr->size() != 2) { 234 return emitOpError() << "expects " << getContractionAttrName() 235 << " to contain two elements"; 236 } 237 } 238 return success(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // MatchStructuredClassifyContractionDimsOp 243 //===----------------------------------------------------------------------===// 244 245 DiagnosedSilenceableFailure 246 transform::MatchStructuredClassifyContractionDimsOp::matchOperation( 247 Operation *current, transform::TransformResults &results, 248 transform::TransformState &state) { 249 FailureOr<linalg::ContractionDimensions> contractionDims = 250 linalg::inferContractionDims(cast<linalg::LinalgOp>(current)); 251 if (failed(contractionDims)) 252 return emitSilenceableError() << "could not infer contraction dimensions"; 253 254 MLIRContext *context = current->getContext(); 255 Builder builder(context); 256 auto makeI64Attrs = [&](ArrayRef<unsigned> values) { 257 return llvm::to_vector( 258 llvm::map_range(values, [&](unsigned value) -> Attribute { 259 return builder.getI64IntegerAttr(value); 260 })); 261 }; 262 results.setParams(cast<OpResult>(getBatch()), 263 makeI64Attrs(contractionDims->batch)); 264 results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m)); 265 results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n)); 266 results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k)); 267 return DiagnosedSilenceableFailure::success(); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // MatchStructuredClassifyConvolutionDimsOp 272 //===----------------------------------------------------------------------===// 273 274 DiagnosedSilenceableFailure 275 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( 276 Operation *current, transform::TransformResults &results, 277 transform::TransformState &state) { 278 FailureOr<linalg::ConvolutionDimensions> convolutionDims = 279 linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current)); 280 if (failed(convolutionDims)) 281 return emitSilenceableError() << "could not infer convolution dimensions"; 282 283 MLIRContext *context = current->getContext(); 284 Builder builder(context); 285 auto makeI64Attrs = [&](ArrayRef<unsigned> values) { 286 return llvm::to_vector( 287 llvm::map_range(values, [&](unsigned value) -> Attribute { 288 return builder.getI64IntegerAttr(value); 289 })); 290 }; 291 results.setParams(cast<OpResult>(getBatch()), 292 makeI64Attrs(convolutionDims->batch)); 293 results.setParams(cast<OpResult>(getOutputImage()), 294 makeI64Attrs(convolutionDims->outputImage)); 295 results.setParams(cast<OpResult>(getOutputChannel()), 296 makeI64Attrs(convolutionDims->outputChannel)); 297 results.setParams(cast<OpResult>(getFilterLoop()), 298 makeI64Attrs(convolutionDims->filterLoop)); 299 results.setParams(cast<OpResult>(getInputChannel()), 300 makeI64Attrs(convolutionDims->inputChannel)); 301 results.setParams(cast<OpResult>(getDepth()), 302 makeI64Attrs(convolutionDims->depth)); 303 304 auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) { 305 return llvm::to_vector( 306 llvm::map_range(values, [&](int64_t value) -> Attribute { 307 return builder.getI64IntegerAttr(value); 308 })); 309 }; 310 results.setParams(cast<OpResult>(getStrides()), 311 makeI64AttrsFromI64(convolutionDims->strides)); 312 results.setParams(cast<OpResult>(getDilations()), 313 makeI64AttrsFromI64(convolutionDims->dilations)); 314 return DiagnosedSilenceableFailure::success(); 315 } 316 317 //===----------------------------------------------------------------------===// 318 // Utilities for structured match predicates. 319 //===----------------------------------------------------------------------===// 320 321 /// Checks if all values from `list` are also contained in `reference`. Returns 322 /// a silenceable error with the given message at the given location when it is 323 /// not the case. The error message must contain the "{0}" placeholder that 324 /// will be substituted with the value from `list` that is not contained in 325 /// `reference`. 326 static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference, 327 ArrayRef<int64_t> list, 328 Location loc, 329 const char *message) { 330 for (int64_t value : list) { 331 if (llvm::any_of(reference, [&](unsigned ref) { 332 return static_cast<int64_t>(ref) == value; 333 })) { 334 continue; 335 } 336 return emitSilenceableFailure(loc) << llvm::formatv(message, value); 337 } 338 return DiagnosedSilenceableFailure::success(); 339 } 340 341 //===----------------------------------------------------------------------===// 342 // MatchStructuredDimOp 343 //===----------------------------------------------------------------------===// 344 345 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation( 346 Operation *current, transform::TransformResults &results, 347 transform::TransformState &state) { 348 auto linalgOp = cast<linalg::LinalgOp>(current); 349 SmallVector<int64_t> dimensions; 350 DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions); 351 if (!diag.succeeded()) 352 return diag; 353 354 // If asked to check for the kind of dimension, perform the check. 355 if (getParallel() || getReduction()) { 356 SmallVector<unsigned> reference; 357 if (getParallel()) 358 linalgOp.getParallelDims(reference); 359 else if (getReduction()) 360 linalgOp.getReductionDims(reference); 361 362 DiagnosedSilenceableFailure diag = 363 containsAll(reference, dimensions, getLoc(), 364 getParallel() ? "expects dimension #{0} to be parallel" 365 : "expects dimension #{0} to be reduction"); 366 if (!diag.succeeded()) 367 return diag; 368 } 369 370 // If not capturing, we are done here. 371 if (!getResult()) 372 return diag; 373 374 SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges(); 375 Builder builder(current); 376 SmallVector<Attribute> captured = llvm::to_vector( 377 llvm::map_range(dimensions, [&](int64_t dim) -> Attribute { 378 return builder.getI64IntegerAttr(ranges[dim]); 379 })); 380 results.setParams(cast<OpResult>(getResult()), captured); 381 return DiagnosedSilenceableFailure::success(); 382 } 383 384 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor( 385 linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) { 386 DiagnosedSilenceableFailure diag = 387 expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(), 388 getRawDimList(), op.getNumLoops(), dims); 389 if (diag.isSilenceableFailure()) { 390 diag.attachNote(op->getLoc()) 391 << "while considering dimensions of this payload operation"; 392 } 393 return diag; 394 } 395 396 LogicalResult transform::MatchStructuredDimOp::verify() { 397 if (getParallel() && getReduction()) { 398 return emitOpError() << "cannot request the same dimension to be both " 399 "parallel and reduction"; 400 } 401 return verifyTransformMatchDimsOp(getOperation(), getRawDimList(), 402 getIsInverted(), getIsAll()); 403 } 404 405 //===----------------------------------------------------------------------===// 406 // MatchStructuredElementalBitwidthOp 407 //===----------------------------------------------------------------------===// 408 409 DiagnosedSilenceableFailure 410 transform::MatchStructuredElementalBitwidthOp::matchValue( 411 Value current, transform::TransformResults &results, 412 transform::TransformState &state) { 413 auto setupResult = [&](int64_t bitwidth) { 414 Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth); 415 results.setParams(cast<OpResult>(getResult()), {attr}); 416 return DiagnosedSilenceableFailure::success(); 417 }; 418 419 Type type = current.getType(); 420 if (type.isIntOrFloat()) 421 return setupResult(type.getIntOrFloatBitWidth()); 422 423 if (auto shapedType = dyn_cast<ShapedType>(type)) { 424 if (shapedType.getElementType().isIntOrFloat()) 425 return setupResult(shapedType.getElementTypeBitWidth()); 426 } 427 return emitSilenceableError() 428 << "unsupported type for bitwidth extraction: " << type; 429 } 430 431 //===----------------------------------------------------------------------===// 432 // MatchStructuredInputOp 433 //===----------------------------------------------------------------------===// 434 435 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation( 436 Operation *current, transform::TransformResults &results, 437 transform::TransformState &state) { 438 auto linalgOp = cast<linalg::LinalgOp>(current); 439 SmallVector<int64_t> positions; 440 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); 441 if (!diag.succeeded()) 442 return diag; 443 444 SmallVector<MappedValue> operandMapping; 445 operandMapping.reserve(positions.size()); 446 for (int64_t position : positions) { 447 AffineMap indexingMap = 448 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position)); 449 if (getPermutation() && !indexingMap.isPermutation()) { 450 return emitSilenceableError() << "the indexing map for input #" 451 << position << " is not a permutation"; 452 } 453 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { 454 return emitSilenceableError() 455 << "the indexing map for input #" << position 456 << " is not a projected permutation"; 457 } 458 459 // If capture not requested, skip it. 460 if (!getResult()) 461 continue; 462 463 if (isa<AffineMapParamType>(getResult().getType())) { 464 operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); 465 continue; 466 } 467 468 Value operand = linalgOp.getDpsInputOperand(position)->get(); 469 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { 470 operandMapping.emplace_back(operand); 471 continue; 472 } 473 474 Operation *operandProducer = operand.getDefiningOp(); 475 if (!operandProducer) { 476 return emitSilenceableError() 477 << "input #" << position << " is not produced by an operation"; 478 } 479 operandMapping.emplace_back(operandProducer); 480 } 481 if (getResult()) 482 results.setMappedValues(cast<OpResult>(getResult()), operandMapping); 483 return DiagnosedSilenceableFailure::success(); 484 } 485 486 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor( 487 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { 488 DiagnosedSilenceableFailure diag = expandTargetSpecification( 489 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 490 op.getNumDpsInputs(), positions); 491 if (diag.isSilenceableFailure()) { 492 diag.attachNote(op->getLoc()) 493 << "while considering DPS inputs of this payload operation"; 494 } 495 return diag; 496 } 497 498 /// Verifies a matcher op for structured input or output, specifically the 499 /// attributes specifying the operand positions. 500 template <typename OpTy> 501 LogicalResult verifyStructuredOperandOp(OpTy op) { 502 if (op.getPermutation() && op.getProjectedPermutation()) { 503 return op.emitOpError() 504 << op.getPermutationAttrName() << " and " 505 << op.getProjectedPermutationAttrName() << " are mutually exclusive"; 506 } 507 if (op.getRawPositionList().size() > 1 && op.getResult()) { 508 return op.emitOpError() 509 << "cannot bind multiple inputs/inits to the same value"; 510 } 511 512 return success(); 513 } 514 515 LogicalResult transform::MatchStructuredInputOp::verify() { 516 if (failed(verifyStructuredOperandOp(*this))) 517 return failure(); 518 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 519 getIsInverted(), getIsAll()); 520 } 521 522 //===----------------------------------------------------------------------===// 523 // MatchStructuredInitOp 524 //===----------------------------------------------------------------------===// 525 526 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation( 527 Operation *current, transform::TransformResults &results, 528 transform::TransformState &state) { 529 auto linalgOp = cast<linalg::LinalgOp>(current); 530 SmallVector<int64_t> positions; 531 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); 532 if (!diag.succeeded()) 533 return diag; 534 535 SmallVector<MappedValue> operandMapping; 536 operandMapping.reserve(positions.size()); 537 for (int64_t position : positions) { 538 AffineMap indexingMap = 539 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position)); 540 if (getPermutation() && !indexingMap.isPermutation()) { 541 return emitSilenceableError() << "the indexing map for output(init) #" 542 << position << " is not a permutation"; 543 } 544 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { 545 return emitSilenceableError() << "the indexing map for output(init) #" 546 << position << " is not a permutation"; 547 } 548 549 // If capture not requested, skip it. 550 if (!getResult()) 551 continue; 552 553 if (isa<AffineMapParamType>(getResult().getType())) { 554 operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); 555 continue; 556 } 557 558 Value operand = linalgOp.getDpsInitOperand(position)->get(); 559 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { 560 operandMapping.emplace_back(operand); 561 continue; 562 } 563 564 Operation *operandProducer = operand.getDefiningOp(); 565 if (!operandProducer) { 566 return emitSilenceableError() << "output(init) #" << position 567 << " is not produced by an operation"; 568 } 569 operandMapping.emplace_back(operandProducer); 570 } 571 if (getResult()) 572 results.setMappedValues(cast<OpResult>(getResult()), operandMapping); 573 return DiagnosedSilenceableFailure::success(); 574 } 575 576 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor( 577 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) { 578 DiagnosedSilenceableFailure diag = expandTargetSpecification( 579 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), 580 op.getNumDpsInits(), positions); 581 if (diag.isSilenceableFailure()) { 582 diag.attachNote(op->getLoc()) 583 << "while considering DPS inits (outputs) of this payload operation"; 584 } 585 return diag; 586 } 587 588 LogicalResult transform::MatchStructuredInitOp::verify() { 589 if (failed(verifyStructuredOperandOp(*this))) 590 return failure(); 591 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), 592 getIsInverted(), getIsAll()); 593 } 594 595 //===----------------------------------------------------------------------===// 596 // MatchStructuredNumInputsOp 597 //===----------------------------------------------------------------------===// 598 599 DiagnosedSilenceableFailure 600 transform::MatchStructuredNumInputsOp::matchOperation( 601 Operation *current, transform::TransformResults &results, 602 transform::TransformState &state) { 603 auto linalgOp = cast<linalg::LinalgOp>(current); 604 Attribute attr = 605 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs()); 606 results.setParams(cast<OpResult>(getResult()), {attr}); 607 return DiagnosedSilenceableFailure::success(); 608 } 609 610 //===----------------------------------------------------------------------===// 611 // MatchStructuredNumInitsOp 612 //===----------------------------------------------------------------------===// 613 614 DiagnosedSilenceableFailure 615 transform::MatchStructuredNumInitsOp::matchOperation( 616 Operation *current, transform::TransformResults &results, 617 transform::TransformState &state) { 618 auto linalgOp = cast<linalg::LinalgOp>(current); 619 Attribute attr = 620 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits()); 621 results.setParams(cast<OpResult>(getResult()), {attr}); 622 return DiagnosedSilenceableFailure::success(); 623 } 624 625 //===----------------------------------------------------------------------===// 626 // MatchStructuredRankOp 627 //===----------------------------------------------------------------------===// 628 629 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation( 630 Operation *current, transform::TransformResults &results, 631 transform::TransformState &state) { 632 auto linalgOp = cast<linalg::LinalgOp>(current); 633 int64_t numLoops = linalgOp.getNumLoops(); 634 Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops); 635 results.setParams(cast<OpResult>(getRank()), {attr}); 636 return DiagnosedSilenceableFailure::success(); 637 } 638 639 //===----------------------------------------------------------------------===// 640 // MatchStructuredResultOp 641 //===----------------------------------------------------------------------===// 642 643 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( 644 Operation *op, transform::TransformResults &results, 645 transform::TransformState &state) { 646 auto linalgOp = cast<linalg::LinalgOp>(op); 647 int64_t position; 648 DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position); 649 if (!diag.succeeded()) 650 return diag; 651 652 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); 653 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) { 654 results.setValues(cast<OpResult>(getResult()), {result}); 655 return DiagnosedSilenceableFailure::success(); 656 } 657 658 if (result.getUsers().empty()) { 659 return emitSilenceableError() 660 << "no users of the result #" << getPosition(); 661 } 662 Operation *firstUser = *result.getUsers().begin(); 663 if (getAny()) { 664 results.set(cast<OpResult>(getResult()), {firstUser}); 665 return DiagnosedSilenceableFailure::success(); 666 } 667 if (getSingle()) { 668 if (!llvm::hasSingleElement(result.getUsers())) { 669 return emitSilenceableError() 670 << "more than one result user with single user requested"; 671 } 672 results.set(cast<OpResult>(getResult()), {firstUser}); 673 return DiagnosedSilenceableFailure::success(); 674 } 675 676 return emitDefiniteFailure() << "unknown sub-predicate"; 677 } 678 679 DiagnosedSilenceableFailure 680 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, 681 int64_t &position) { 682 auto rawPosition = static_cast<int64_t>(getPosition()); 683 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition; 684 if (position >= op.getNumDpsInits() || position < 0) { 685 return emitSilenceableError() 686 << "position " << rawPosition 687 << " overflows the number of results(ints) of the payload operation"; 688 } 689 return DiagnosedSilenceableFailure::success(); 690 } 691 692 LogicalResult transform::MatchStructuredResultOp::verify() { 693 if ((getAny() || getSingle()) ^ 694 isa<TransformHandleTypeInterface>(getResult().getType())) { 695 return emitOpError() << "expects either the any/single keyword or the type " 696 "value handle result type"; 697 } 698 if (getAny() && getSingle()) { 699 return emitOpError() << "'any' and 'single' are mutually exclusive"; 700 } 701 return success(); 702 } 703 704 //===----------------------------------------------------------------------===// 705 // MatchStructuredYieldOp 706 //===----------------------------------------------------------------------===// 707 708 void transform::MatchStructuredYieldOp::getEffects( 709 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 710 onlyReadsHandle(getHandlesMutable(), effects); 711 onlyReadsPayload(effects); 712 } 713 714 void transform::MatchStructuredYieldOp::build(OpBuilder &builder, 715 OperationState &state) { 716 build(builder, state, ValueRange()); 717 } 718 719 #define GET_OP_CLASSES 720 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" 721