xref: /llvm-project/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
10 #include "mlir/Analysis/SliceAnalysis.h"
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
13 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
16 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/Interfaces/FunctionImplementation.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 
27 //===----------------------------------------------------------------------===//
28 // StructuredMatchOp
29 //===----------------------------------------------------------------------===//
30 
31 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32     Operation *current, transform::TransformResults &results,
33     transform::TransformState &state) {
34   // First, check if the payload operation is a structured Linalg operation.
35   if (!isa<linalg::LinalgOp>(current)) {
36     if (getFailurePropagationMode().value_or(
37             FailurePropagationMode::Propagate) ==
38         FailurePropagationMode::Propagate) {
39       return emitSilenceableError() << "expected a Linalg op";
40     }
41     // If errors are suppressed, succeed and set all results to empty lists.
42     LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
43     results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
44     return DiagnosedSilenceableFailure::success();
45   }
46 
47   // Bind `current` to the block argument.
48   auto scope = state.make_region_scope(getBodyRegion());
49   if (failed(state.mapBlockArgument(getBody()->getArgument(0),
50                                     MappedValue(current)))) {
51     return DiagnosedSilenceableFailure::definiteFailure();
52   }
53 
54   for (Operation &nested : getBody()->without_terminator()) {
55     DiagnosedSilenceableFailure diag =
56         state.applyTransform(cast<TransformOpInterface>(nested));
57     if (diag.isDefiniteFailure())
58       return diag;
59     if (diag.succeeded())
60       continue;
61 
62     // If propagating errors, do this immediately.
63     assert(diag.isSilenceableFailure());
64     if (getFailurePropagationMode().value_or(
65             FailurePropagationMode::Propagate) ==
66         FailurePropagationMode::Propagate) {
67       return diag;
68     }
69 
70     // If suppressing errors, print the message into the debug stream before
71     // silencing it. Then set all results value that are already known.
72     // Results come from the terminator operands, which may be defined in the
73     // (single) block of this operation or above it. When they are defined
74     // above, they are known to be mapped at this point per SSA dominance.
75     // When they are defined in this block, we additionally check if we have
76     // already applied the operation that defines them. If not, the
77     // corresponding results will be set to empty lists.
78     LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
79                       << "\n");
80     (void)diag.silence();
81     SmallVector<OpOperand *> undefinedOperands;
82     for (OpOperand &terminatorOperand :
83          getBody()->getTerminator()->getOpOperands()) {
84       Operation *definingOp = terminatorOperand.get().getDefiningOp();
85       if (!definingOp)
86         continue;
87       if (definingOp->getBlock() != getBody())
88         continue;
89       if (definingOp->isBeforeInBlock(&nested))
90         continue;
91 
92       undefinedOperands.push_back(&terminatorOperand);
93     }
94 
95     SmallVector<SmallVector<transform::MappedValue>> mappings;
96     auto filtered = llvm::make_filter_range(
97         getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
98           return !llvm::is_contained(undefinedOperands, &opOperand);
99         });
100     SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
101         filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
102     detail::prepareValueMappings(mappings, definedOperands, state);
103     for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
104       results.setMappedValues(getResults()[operand.getOperandNumber()],
105                               mapping);
106     }
107     results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
108     return DiagnosedSilenceableFailure::success();
109   }
110 
111   // Set the results.
112   detail::forwardTerminatorOperands(getBody(), state, results);
113   return DiagnosedSilenceableFailure::success();
114 }
115 
116 void transform::MatchStructuredOp::getEffects(
117     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
118   onlyReadsHandle(getCurrentMutable(), effects);
119   onlyReadsPayload(effects);
120   producesHandle(getOperation()->getOpResults(), effects);
121 }
122 
123 LogicalResult transform::MatchStructuredOp::verify() {
124   if (getBody()->getNumArguments() != 1)
125     return emitOpError() << "expected one body argument";
126   if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
127     return emitOpError() << "expected body argument to implement "
128                             "TransformHandleTypeInterface";
129   }
130   for (Operation &nested : getBody()->without_terminator()) {
131     if (isa<MatchOpInterface>(nested))
132       continue;
133     InFlightDiagnostic diag =
134         emitOpError()
135         << "expects nested operations to implement MatchOpInterface";
136     diag.attachNote(nested.getLoc()) << "offending operation";
137     return diag;
138   }
139   return success();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // StructuredOpPredicateOpTrait
144 //===----------------------------------------------------------------------===//
145 
146 LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
147     Operation *op, Value structuredOpHandle) {
148   if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
149     return op->emitOpError() << "expects parent op to be '"
150                              << MatchStructuredOp::getOperationName() << "'";
151   }
152 
153   // Bail out here, let the verifier of the parent complain.
154   Operation *parent = op->getParentOp();
155   if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
156       parent->getRegion(0).front().getNumArguments() < 1)
157     return success();
158 
159   if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
160     return op->emitOpError()
161            << "expected predicate to apply to the surrounding structured op";
162   }
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // MatchStructuredBodyOp
168 //===----------------------------------------------------------------------===//
169 
170 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
171     Operation *current, transform::TransformResults &results,
172     transform::TransformState &state) {
173   auto linalgOp = cast<linalg::LinalgOp>(current);
174   if (std::optional<uint64_t> position = getReductionPosition()) {
175     SmallVector<Operation *> combinerOps;
176     if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
177                         combinerOps)) {
178       return emitSilenceableError() << "could not match reduction";
179     }
180     if (combinerOps.size() != 1) {
181       return emitSilenceableError() << "reduction combiner is not a single op";
182     }
183     return DiagnosedSilenceableFailure::success();
184   }
185   if (getPassthrough()) {
186     Block &body = linalgOp->getRegion(0).front();
187     if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
188       return emitSilenceableError() << "not a passthrough";
189     }
190     return DiagnosedSilenceableFailure::success();
191   }
192   if (getElementwise()) {
193     if (!isElementwise(linalgOp))
194       return emitSilenceableError() << "not elementwise";
195     return DiagnosedSilenceableFailure::success();
196   }
197   if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198     Block &body = linalgOp->getRegion(0).front();
199     std::string message;
200     llvm::raw_string_ostream os(message);
201     bool result = linalg::detail::isContractionBody(
202         body,
203         [&](Operation *elem, Operation *red) {
204           return elem->getName().getStringRef() ==
205                      cast<StringAttr>((*contractionOps)[0]).getValue() &&
206                  red->getName().getStringRef() ==
207                      cast<StringAttr>((*contractionOps)[1]).getValue();
208         },
209         os);
210     if (result)
211       return DiagnosedSilenceableFailure::success();
212     return emitSilenceableError() << "contraction: " << message;
213   }
214   return emitDefiniteFailure() << "unknown body condition";
215 }
216 
217 LogicalResult transform::MatchStructuredBodyOp::verify() {
218   int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219                        getElementwise() + getContraction().has_value();
220 
221   if (numOptions > 1) {
222     std::string attributeNames;
223     llvm::raw_string_ostream os(attributeNames);
224     llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
225                                                getPassthroughAttrName(),
226                                                getElementwiseAttrName(),
227                                                getContractionAttrName()},
228                           os);
229     return emitOpError() << "only one of {" << attributeNames << "} is allowed";
230   }
231 
232   if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
233     if (contractionAttr->size() != 2) {
234       return emitOpError() << "expects " << getContractionAttrName()
235                            << " to contain two elements";
236     }
237   }
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // MatchStructuredClassifyContractionDimsOp
243 //===----------------------------------------------------------------------===//
244 
245 DiagnosedSilenceableFailure
246 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
247     Operation *current, transform::TransformResults &results,
248     transform::TransformState &state) {
249   FailureOr<linalg::ContractionDimensions> contractionDims =
250       linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
251   if (failed(contractionDims))
252     return emitSilenceableError() << "could not infer contraction dimensions";
253 
254   MLIRContext *context = current->getContext();
255   Builder builder(context);
256   auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
257     return llvm::to_vector(
258         llvm::map_range(values, [&](unsigned value) -> Attribute {
259           return builder.getI64IntegerAttr(value);
260         }));
261   };
262   results.setParams(cast<OpResult>(getBatch()),
263                     makeI64Attrs(contractionDims->batch));
264   results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
265   results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
266   results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
267   return DiagnosedSilenceableFailure::success();
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // MatchStructuredClassifyConvolutionDimsOp
272 //===----------------------------------------------------------------------===//
273 
274 DiagnosedSilenceableFailure
275 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
276     Operation *current, transform::TransformResults &results,
277     transform::TransformState &state) {
278   FailureOr<linalg::ConvolutionDimensions> convolutionDims =
279       linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
280   if (failed(convolutionDims))
281     return emitSilenceableError() << "could not infer convolution dimensions";
282 
283   MLIRContext *context = current->getContext();
284   Builder builder(context);
285   auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
286     return llvm::to_vector(
287         llvm::map_range(values, [&](unsigned value) -> Attribute {
288           return builder.getI64IntegerAttr(value);
289         }));
290   };
291   results.setParams(cast<OpResult>(getBatch()),
292                     makeI64Attrs(convolutionDims->batch));
293   results.setParams(cast<OpResult>(getOutputImage()),
294                     makeI64Attrs(convolutionDims->outputImage));
295   results.setParams(cast<OpResult>(getOutputChannel()),
296                     makeI64Attrs(convolutionDims->outputChannel));
297   results.setParams(cast<OpResult>(getFilterLoop()),
298                     makeI64Attrs(convolutionDims->filterLoop));
299   results.setParams(cast<OpResult>(getInputChannel()),
300                     makeI64Attrs(convolutionDims->inputChannel));
301   results.setParams(cast<OpResult>(getDepth()),
302                     makeI64Attrs(convolutionDims->depth));
303 
304   auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
305     return llvm::to_vector(
306         llvm::map_range(values, [&](int64_t value) -> Attribute {
307           return builder.getI64IntegerAttr(value);
308         }));
309   };
310   results.setParams(cast<OpResult>(getStrides()),
311                     makeI64AttrsFromI64(convolutionDims->strides));
312   results.setParams(cast<OpResult>(getDilations()),
313                     makeI64AttrsFromI64(convolutionDims->dilations));
314   return DiagnosedSilenceableFailure::success();
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // Utilities for structured match predicates.
319 //===----------------------------------------------------------------------===//
320 
321 /// Checks if all values from `list` are also contained in `reference`. Returns
322 /// a silenceable error with the given message at the given location when it is
323 /// not the case. The error message must contain the "{0}" placeholder that
324 /// will be substituted with the value from `list` that is not contained in
325 /// `reference`.
326 static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
327                                                ArrayRef<int64_t> list,
328                                                Location loc,
329                                                const char *message) {
330   for (int64_t value : list) {
331     if (llvm::any_of(reference, [&](unsigned ref) {
332           return static_cast<int64_t>(ref) == value;
333         })) {
334       continue;
335     }
336     return emitSilenceableFailure(loc) << llvm::formatv(message, value);
337   }
338   return DiagnosedSilenceableFailure::success();
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // MatchStructuredDimOp
343 //===----------------------------------------------------------------------===//
344 
345 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
346     Operation *current, transform::TransformResults &results,
347     transform::TransformState &state) {
348   auto linalgOp = cast<linalg::LinalgOp>(current);
349   SmallVector<int64_t> dimensions;
350   DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
351   if (!diag.succeeded())
352     return diag;
353 
354   // If asked to check for the kind of dimension, perform the check.
355   if (getParallel() || getReduction()) {
356     SmallVector<unsigned> reference;
357     if (getParallel())
358       linalgOp.getParallelDims(reference);
359     else if (getReduction())
360       linalgOp.getReductionDims(reference);
361 
362     DiagnosedSilenceableFailure diag =
363         containsAll(reference, dimensions, getLoc(),
364                     getParallel() ? "expects dimension #{0} to be parallel"
365                                   : "expects dimension #{0} to be reduction");
366     if (!diag.succeeded())
367       return diag;
368   }
369 
370   // If not capturing, we are done here.
371   if (!getResult())
372     return diag;
373 
374   SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
375   Builder builder(current);
376   SmallVector<Attribute> captured = llvm::to_vector(
377       llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
378         return builder.getI64IntegerAttr(ranges[dim]);
379       }));
380   results.setParams(cast<OpResult>(getResult()), captured);
381   return DiagnosedSilenceableFailure::success();
382 }
383 
384 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
385     linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
386   DiagnosedSilenceableFailure diag =
387       expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
388                                 getRawDimList(), op.getNumLoops(), dims);
389   if (diag.isSilenceableFailure()) {
390     diag.attachNote(op->getLoc())
391         << "while considering dimensions of this payload operation";
392   }
393   return diag;
394 }
395 
396 LogicalResult transform::MatchStructuredDimOp::verify() {
397   if (getParallel() && getReduction()) {
398     return emitOpError() << "cannot request the same dimension to be both "
399                             "parallel and reduction";
400   }
401   return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
402                                     getIsInverted(), getIsAll());
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // MatchStructuredElementalBitwidthOp
407 //===----------------------------------------------------------------------===//
408 
409 DiagnosedSilenceableFailure
410 transform::MatchStructuredElementalBitwidthOp::matchValue(
411     Value current, transform::TransformResults &results,
412     transform::TransformState &state) {
413   auto setupResult = [&](int64_t bitwidth) {
414     Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
415     results.setParams(cast<OpResult>(getResult()), {attr});
416     return DiagnosedSilenceableFailure::success();
417   };
418 
419   Type type = current.getType();
420   if (type.isIntOrFloat())
421     return setupResult(type.getIntOrFloatBitWidth());
422 
423   if (auto shapedType = dyn_cast<ShapedType>(type)) {
424     if (shapedType.getElementType().isIntOrFloat())
425       return setupResult(shapedType.getElementTypeBitWidth());
426   }
427   return emitSilenceableError()
428          << "unsupported type for bitwidth extraction: " << type;
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // MatchStructuredInputOp
433 //===----------------------------------------------------------------------===//
434 
435 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
436     Operation *current, transform::TransformResults &results,
437     transform::TransformState &state) {
438   auto linalgOp = cast<linalg::LinalgOp>(current);
439   SmallVector<int64_t> positions;
440   DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
441   if (!diag.succeeded())
442     return diag;
443 
444   SmallVector<MappedValue> operandMapping;
445   operandMapping.reserve(positions.size());
446   for (int64_t position : positions) {
447     AffineMap indexingMap =
448         linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
449     if (getPermutation() && !indexingMap.isPermutation()) {
450       return emitSilenceableError() << "the indexing map for input #"
451                                     << position << " is not a permutation";
452     }
453     if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
454       return emitSilenceableError()
455              << "the indexing map for input #" << position
456              << " is not a projected permutation";
457     }
458 
459     // If capture not requested, skip it.
460     if (!getResult())
461       continue;
462 
463     if (isa<AffineMapParamType>(getResult().getType())) {
464       operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
465       continue;
466     }
467 
468     Value operand = linalgOp.getDpsInputOperand(position)->get();
469     if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
470       operandMapping.emplace_back(operand);
471       continue;
472     }
473 
474     Operation *operandProducer = operand.getDefiningOp();
475     if (!operandProducer) {
476       return emitSilenceableError()
477              << "input #" << position << " is not produced by an operation";
478     }
479     operandMapping.emplace_back(operandProducer);
480   }
481   if (getResult())
482     results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
483   return DiagnosedSilenceableFailure::success();
484 }
485 
486 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
487     linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
488   DiagnosedSilenceableFailure diag = expandTargetSpecification(
489       getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
490       op.getNumDpsInputs(), positions);
491   if (diag.isSilenceableFailure()) {
492     diag.attachNote(op->getLoc())
493         << "while considering DPS inputs of this payload operation";
494   }
495   return diag;
496 }
497 
498 /// Verifies a matcher op for structured input or output, specifically the
499 /// attributes specifying the operand positions.
500 template <typename OpTy>
501 LogicalResult verifyStructuredOperandOp(OpTy op) {
502   if (op.getPermutation() && op.getProjectedPermutation()) {
503     return op.emitOpError()
504            << op.getPermutationAttrName() << " and "
505            << op.getProjectedPermutationAttrName() << " are mutually exclusive";
506   }
507   if (op.getRawPositionList().size() > 1 && op.getResult()) {
508     return op.emitOpError()
509            << "cannot bind multiple inputs/inits to the same value";
510   }
511 
512   return success();
513 }
514 
515 LogicalResult transform::MatchStructuredInputOp::verify() {
516   if (failed(verifyStructuredOperandOp(*this)))
517     return failure();
518   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
519                                     getIsInverted(), getIsAll());
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // MatchStructuredInitOp
524 //===----------------------------------------------------------------------===//
525 
526 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
527     Operation *current, transform::TransformResults &results,
528     transform::TransformState &state) {
529   auto linalgOp = cast<linalg::LinalgOp>(current);
530   SmallVector<int64_t> positions;
531   DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
532   if (!diag.succeeded())
533     return diag;
534 
535   SmallVector<MappedValue> operandMapping;
536   operandMapping.reserve(positions.size());
537   for (int64_t position : positions) {
538     AffineMap indexingMap =
539         linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
540     if (getPermutation() && !indexingMap.isPermutation()) {
541       return emitSilenceableError() << "the indexing map for output(init) #"
542                                     << position << " is not a permutation";
543     }
544     if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
545       return emitSilenceableError() << "the indexing map for output(init) #"
546                                     << position << " is not a permutation";
547     }
548 
549     // If capture not requested, skip it.
550     if (!getResult())
551       continue;
552 
553     if (isa<AffineMapParamType>(getResult().getType())) {
554       operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
555       continue;
556     }
557 
558     Value operand = linalgOp.getDpsInitOperand(position)->get();
559     if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
560       operandMapping.emplace_back(operand);
561       continue;
562     }
563 
564     Operation *operandProducer = operand.getDefiningOp();
565     if (!operandProducer) {
566       return emitSilenceableError() << "output(init) #" << position
567                                     << " is not produced by an operation";
568     }
569     operandMapping.emplace_back(operandProducer);
570   }
571   if (getResult())
572     results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
573   return DiagnosedSilenceableFailure::success();
574 }
575 
576 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
577     linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
578   DiagnosedSilenceableFailure diag = expandTargetSpecification(
579       getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
580       op.getNumDpsInits(), positions);
581   if (diag.isSilenceableFailure()) {
582     diag.attachNote(op->getLoc())
583         << "while considering DPS inits (outputs) of this payload operation";
584   }
585   return diag;
586 }
587 
588 LogicalResult transform::MatchStructuredInitOp::verify() {
589   if (failed(verifyStructuredOperandOp(*this)))
590     return failure();
591   return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
592                                     getIsInverted(), getIsAll());
593 }
594 
595 //===----------------------------------------------------------------------===//
596 // MatchStructuredNumInputsOp
597 //===----------------------------------------------------------------------===//
598 
599 DiagnosedSilenceableFailure
600 transform::MatchStructuredNumInputsOp::matchOperation(
601     Operation *current, transform::TransformResults &results,
602     transform::TransformState &state) {
603   auto linalgOp = cast<linalg::LinalgOp>(current);
604   Attribute attr =
605       Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
606   results.setParams(cast<OpResult>(getResult()), {attr});
607   return DiagnosedSilenceableFailure::success();
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // MatchStructuredNumInitsOp
612 //===----------------------------------------------------------------------===//
613 
614 DiagnosedSilenceableFailure
615 transform::MatchStructuredNumInitsOp::matchOperation(
616     Operation *current, transform::TransformResults &results,
617     transform::TransformState &state) {
618   auto linalgOp = cast<linalg::LinalgOp>(current);
619   Attribute attr =
620       Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
621   results.setParams(cast<OpResult>(getResult()), {attr});
622   return DiagnosedSilenceableFailure::success();
623 }
624 
625 //===----------------------------------------------------------------------===//
626 // MatchStructuredRankOp
627 //===----------------------------------------------------------------------===//
628 
629 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
630     Operation *current, transform::TransformResults &results,
631     transform::TransformState &state) {
632   auto linalgOp = cast<linalg::LinalgOp>(current);
633   int64_t numLoops = linalgOp.getNumLoops();
634   Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
635   results.setParams(cast<OpResult>(getRank()), {attr});
636   return DiagnosedSilenceableFailure::success();
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // MatchStructuredResultOp
641 //===----------------------------------------------------------------------===//
642 
643 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
644     Operation *op, transform::TransformResults &results,
645     transform::TransformState &state) {
646   auto linalgOp = cast<linalg::LinalgOp>(op);
647   int64_t position;
648   DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
649   if (!diag.succeeded())
650     return diag;
651 
652   Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
653   if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
654     results.setValues(cast<OpResult>(getResult()), {result});
655     return DiagnosedSilenceableFailure::success();
656   }
657 
658   if (result.getUsers().empty()) {
659     return emitSilenceableError()
660            << "no users of the result #" << getPosition();
661   }
662   Operation *firstUser = *result.getUsers().begin();
663   if (getAny()) {
664     results.set(cast<OpResult>(getResult()), {firstUser});
665     return DiagnosedSilenceableFailure::success();
666   }
667   if (getSingle()) {
668     if (!llvm::hasSingleElement(result.getUsers())) {
669       return emitSilenceableError()
670              << "more than one result user with single user requested";
671     }
672     results.set(cast<OpResult>(getResult()), {firstUser});
673     return DiagnosedSilenceableFailure::success();
674   }
675 
676   return emitDefiniteFailure() << "unknown sub-predicate";
677 }
678 
679 DiagnosedSilenceableFailure
680 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
681                                                    int64_t &position) {
682   auto rawPosition = static_cast<int64_t>(getPosition());
683   position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
684   if (position >= op.getNumDpsInits() || position < 0) {
685     return emitSilenceableError()
686            << "position " << rawPosition
687            << " overflows the number of results(ints) of the payload operation";
688   }
689   return DiagnosedSilenceableFailure::success();
690 }
691 
692 LogicalResult transform::MatchStructuredResultOp::verify() {
693   if ((getAny() || getSingle()) ^
694       isa<TransformHandleTypeInterface>(getResult().getType())) {
695     return emitOpError() << "expects either the any/single keyword or the type "
696                             "value handle result type";
697   }
698   if (getAny() && getSingle()) {
699     return emitOpError() << "'any' and 'single' are mutually exclusive";
700   }
701   return success();
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // MatchStructuredYieldOp
706 //===----------------------------------------------------------------------===//
707 
708 void transform::MatchStructuredYieldOp::getEffects(
709     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
710   onlyReadsHandle(getHandlesMutable(), effects);
711   onlyReadsPayload(effects);
712 }
713 
714 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
715                                               OperationState &state) {
716   build(builder, state, ValueRange());
717 }
718 
719 #define GET_OP_CLASSES
720 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
721