xref: /llvm-project/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10 
11 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Affine/LoopUtils.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Arith/Utils/Utils.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
19 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
20 #include "mlir/Dialect/SCF/Utils/Utils.h"
21 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
22 #include "mlir/Dialect/Transform/IR/TransformOps.h"
23 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
24 #include "mlir/Dialect/Utils/StaticValueUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/IR/BuiltinAttributes.h"
27 #include "mlir/IR/Dominance.h"
28 #include "mlir/IR/OpDefinition.h"
29 
30 using namespace mlir;
31 using namespace mlir::affine;
32 
33 //===----------------------------------------------------------------------===//
34 // Apply...PatternsOp
35 //===----------------------------------------------------------------------===//
36 
37 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
38     RewritePatternSet &patterns) {
39   scf::populateSCFForLoopCanonicalizationPatterns(patterns);
40 }
41 
42 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
43     TypeConverter &typeConverter, RewritePatternSet &patterns) {
44   scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
45 }
46 
47 void transform::ApplySCFStructuralConversionPatternsOp::
48     populateConversionTargetRules(const TypeConverter &typeConverter,
49                                   ConversionTarget &conversionTarget) {
50   scf::populateSCFStructuralTypeConversionTarget(typeConverter,
51                                                  conversionTarget);
52 }
53 
54 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
55     TypeConverter &typeConverter, RewritePatternSet &patterns) {
56   populateSCFToControlFlowConversionPatterns(patterns);
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // ForallToForOp
61 //===----------------------------------------------------------------------===//
62 
63 DiagnosedSilenceableFailure
64 transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
65                                 transform::TransformResults &results,
66                                 transform::TransformState &state) {
67   auto payload = state.getPayloadOps(getTarget());
68   if (!llvm::hasSingleElement(payload))
69     return emitSilenceableError() << "expected a single payload op";
70 
71   auto target = dyn_cast<scf::ForallOp>(*payload.begin());
72   if (!target) {
73     DiagnosedSilenceableFailure diag =
74         emitSilenceableError() << "expected the payload to be scf.forall";
75     diag.attachNote((*payload.begin())->getLoc()) << "payload op";
76     return diag;
77   }
78 
79   if (!target.getOutputs().empty()) {
80     return emitSilenceableError()
81            << "unsupported shared outputs (didn't bufferize?)";
82   }
83 
84   SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
85 
86   if (getNumResults() != lbs.size()) {
87     DiagnosedSilenceableFailure diag =
88         emitSilenceableError()
89         << "op expects as many results (" << getNumResults()
90         << ") as payload has induction variables (" << lbs.size() << ")";
91     diag.attachNote(target.getLoc()) << "payload op";
92     return diag;
93   }
94 
95   SmallVector<Operation *> opResults;
96   if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
97     DiagnosedSilenceableFailure diag = emitSilenceableError()
98                                        << "failed to convert forall into for";
99     return diag;
100   }
101 
102   for (auto &&[i, res] : llvm::enumerate(opResults)) {
103     results.set(cast<OpResult>(getTransformed()[i]), {res});
104   }
105   return DiagnosedSilenceableFailure::success();
106 }
107 
108 //===----------------------------------------------------------------------===//
// ForallToParallelOp
110 //===----------------------------------------------------------------------===//
111 
112 DiagnosedSilenceableFailure
113 transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
114                                      transform::TransformResults &results,
115                                      transform::TransformState &state) {
116   auto payload = state.getPayloadOps(getTarget());
117   if (!llvm::hasSingleElement(payload))
118     return emitSilenceableError() << "expected a single payload op";
119 
120   auto target = dyn_cast<scf::ForallOp>(*payload.begin());
121   if (!target) {
122     DiagnosedSilenceableFailure diag =
123         emitSilenceableError() << "expected the payload to be scf.forall";
124     diag.attachNote((*payload.begin())->getLoc()) << "payload op";
125     return diag;
126   }
127 
128   if (!target.getOutputs().empty()) {
129     return emitSilenceableError()
130            << "unsupported shared outputs (didn't bufferize?)";
131   }
132 
133   if (getNumResults() != 1) {
134     DiagnosedSilenceableFailure diag = emitSilenceableError()
135                                        << "op expects one result, given "
136                                        << getNumResults();
137     diag.attachNote(target.getLoc()) << "payload op";
138     return diag;
139   }
140 
141   scf::ParallelOp opResult;
142   if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
143     DiagnosedSilenceableFailure diag =
144         emitSilenceableError() << "failed to convert forall into parallel";
145     return diag;
146   }
147 
148   results.set(cast<OpResult>(getTransformed()[0]), {opResult});
149   return DiagnosedSilenceableFailure::success();
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // LoopOutlineOp
154 //===----------------------------------------------------------------------===//
155 
156 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
157 /// the provided rewriter for all operations to remain compatible with the
158 /// rewriting infra, as opposed to just splicing the op in place.
159 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
160                                                 Operation *op) {
161   if (op->getNumRegions() != 1)
162     return nullptr;
163   OpBuilder::InsertionGuard g(b);
164   b.setInsertionPoint(op);
165   scf::ExecuteRegionOp executeRegionOp =
166       b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
167   {
168     OpBuilder::InsertionGuard g(b);
169     b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
170     Operation *clonedOp = b.cloneWithoutRegions(*op);
171     Region &clonedRegion = clonedOp->getRegions().front();
172     assert(clonedRegion.empty() && "expected empty region");
173     b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
174                          clonedRegion.end());
175     b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
176   }
177   b.replaceOp(op, executeRegionOp.getResults());
178   return executeRegionOp;
179 }
180 
181 DiagnosedSilenceableFailure
182 transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
183                                 transform::TransformResults &results,
184                                 transform::TransformState &state) {
185   SmallVector<Operation *> functions;
186   SmallVector<Operation *> calls;
187   DenseMap<Operation *, SymbolTable> symbolTables;
188   for (Operation *target : state.getPayloadOps(getTarget())) {
189     Location location = target->getLoc();
190     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
191     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
192     if (!exec) {
193       DiagnosedSilenceableFailure diag = emitSilenceableError()
194                                          << "failed to outline";
195       diag.attachNote(target->getLoc()) << "target op";
196       return diag;
197     }
198     func::CallOp call;
199     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
200         rewriter, location, exec.getRegion(), getFuncName(), &call);
201 
202     if (failed(outlined))
203       return emitDefaultDefiniteFailure(target);
204 
205     if (symbolTableOp) {
206       SymbolTable &symbolTable =
207           symbolTables.try_emplace(symbolTableOp, symbolTableOp)
208               .first->getSecond();
209       symbolTable.insert(*outlined);
210       call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
211     }
212     functions.push_back(*outlined);
213     calls.push_back(call);
214   }
215   results.set(cast<OpResult>(getFunction()), functions);
216   results.set(cast<OpResult>(getCall()), calls);
217   return DiagnosedSilenceableFailure::success();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // LoopPeelOp
222 //===----------------------------------------------------------------------===//
223 
224 DiagnosedSilenceableFailure
225 transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
226                                   scf::ForOp target,
227                                   transform::ApplyToEachResultList &results,
228                                   transform::TransformState &state) {
229   scf::ForOp result;
230   if (getPeelFront()) {
231     LogicalResult status =
232         scf::peelForLoopFirstIteration(rewriter, target, result);
233     if (failed(status)) {
234       DiagnosedSilenceableFailure diag =
235           emitSilenceableError() << "failed to peel the first iteration";
236       return diag;
237     }
238   } else {
239     LogicalResult status =
240         scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
241     if (failed(status)) {
242       DiagnosedSilenceableFailure diag = emitSilenceableError()
243                                          << "failed to peel the last iteration";
244       return diag;
245     }
246   }
247 
248   results.push_back(target);
249   results.push_back(result);
250 
251   return DiagnosedSilenceableFailure::success();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // LoopPipelineOp
256 //===----------------------------------------------------------------------===//
257 
258 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
259 /// operation to its logical time position given the iteration interval and the
260 /// read latency. The latter is only relevant for vector transfers.
261 static void
262 loopScheduling(scf::ForOp forOp,
263                std::vector<std::pair<Operation *, unsigned>> &schedule,
264                unsigned iterationInterval, unsigned readLatency) {
265   auto getLatency = [&](Operation *op) -> unsigned {
266     if (isa<vector::TransferReadOp>(op))
267       return readLatency;
268     return 1;
269   };
270 
271   std::optional<int64_t> ubConstant =
272       getConstantIntValue(forOp.getUpperBound());
273   std::optional<int64_t> lbConstant =
274       getConstantIntValue(forOp.getLowerBound());
275   DenseMap<Operation *, unsigned> opCycles;
276   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
277   for (Operation &op : forOp.getBody()->getOperations()) {
278     if (isa<scf::YieldOp>(op))
279       continue;
280     unsigned earlyCycle = 0;
281     for (Value operand : op.getOperands()) {
282       Operation *def = operand.getDefiningOp();
283       if (!def)
284         continue;
285       if (ubConstant && lbConstant) {
286         unsigned ubInt = ubConstant.value();
287         unsigned lbInt = lbConstant.value();
288         auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def));
289         earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency);
290       } else {
291         earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
292       }
293     }
294     opCycles[&op] = earlyCycle;
295     wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
296   }
297   for (const auto &it : wrappedSchedule) {
298     for (Operation *op : it.second) {
299       unsigned cycle = opCycles[op];
300       schedule.emplace_back(op, cycle / iterationInterval);
301     }
302   }
303 }
304 
305 DiagnosedSilenceableFailure
306 transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
307                                       scf::ForOp target,
308                                       transform::ApplyToEachResultList &results,
309                                       transform::TransformState &state) {
310   scf::PipeliningOption options;
311   options.getScheduleFn =
312       [this](scf::ForOp forOp,
313              std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
314         loopScheduling(forOp, schedule, getIterationInterval(),
315                        getReadLatency());
316       };
317   scf::ForLoopPipeliningPattern pattern(options, target->getContext());
318   rewriter.setInsertionPoint(target);
319   FailureOr<scf::ForOp> patternResult =
320       scf::pipelineForLoop(rewriter, target, options);
321   if (succeeded(patternResult)) {
322     results.push_back(*patternResult);
323     return DiagnosedSilenceableFailure::success();
324   }
325   return emitDefaultSilenceableFailure(target);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // LoopPromoteIfOneIterationOp
330 //===----------------------------------------------------------------------===//
331 
332 DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
333     transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
334     transform::ApplyToEachResultList &results,
335     transform::TransformState &state) {
336   (void)target.promoteIfSingleIteration(rewriter);
337   return DiagnosedSilenceableFailure::success();
338 }
339 
340 void transform::LoopPromoteIfOneIterationOp::getEffects(
341     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
342   consumesHandle(getTargetMutable(), effects);
343   modifiesPayload(effects);
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // LoopUnrollOp
348 //===----------------------------------------------------------------------===//
349 
350 DiagnosedSilenceableFailure
351 transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
352                                     Operation *op,
353                                     transform::ApplyToEachResultList &results,
354                                     transform::TransformState &state) {
355   LogicalResult result(failure());
356   if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
357     result = loopUnrollByFactor(scfFor, getFactor());
358   else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
359     result = loopUnrollByFactor(affineFor, getFactor());
360   else
361     return emitSilenceableError()
362            << "failed to unroll, incorrect type of payload";
363 
364   if (failed(result))
365     return emitSilenceableError() << "failed to unroll";
366 
367   return DiagnosedSilenceableFailure::success();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // LoopUnrollAndJamOp
372 //===----------------------------------------------------------------------===//
373 
374 DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
375     transform::TransformRewriter &rewriter, Operation *op,
376     transform::ApplyToEachResultList &results,
377     transform::TransformState &state) {
378   LogicalResult result(failure());
379   if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
380     result = loopUnrollJamByFactor(scfFor, getFactor());
381   else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
382     result = loopUnrollJamByFactor(affineFor, getFactor());
383   else
384     return emitSilenceableError()
385            << "failed to unroll and jam, incorrect type of payload";
386 
387   if (failed(result))
388     return emitSilenceableError() << "failed to unroll and jam";
389 
390   return DiagnosedSilenceableFailure::success();
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // LoopCoalesceOp
395 //===----------------------------------------------------------------------===//
396 
397 DiagnosedSilenceableFailure
398 transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
399                                       Operation *op,
400                                       transform::ApplyToEachResultList &results,
401                                       transform::TransformState &state) {
402   LogicalResult result(failure());
403   if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
404     result = coalescePerfectlyNestedSCFForLoops(scfForOp);
405   else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
406     result = coalescePerfectlyNestedAffineLoops(affineForOp);
407 
408   results.push_back(op);
409   if (failed(result)) {
410     DiagnosedSilenceableFailure diag = emitSilenceableError()
411                                        << "failed to coalesce";
412     return diag;
413   }
414   return DiagnosedSilenceableFailure::success();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // TakeAssumedBranchOp
419 //===----------------------------------------------------------------------===//
420 /// Replaces the given op with the contents of the given single-block region,
421 /// using the operands of the block terminator to replace operation results.
422 static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
423                                 Region &region) {
424   assert(llvm::hasSingleElement(region) && "expected single-region block");
425   Block *block = &region.front();
426   Operation *terminator = block->getTerminator();
427   ValueRange results = terminator->getOperands();
428   rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
429   rewriter.replaceOp(op, results);
430   rewriter.eraseOp(terminator);
431 }
432 
433 DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
434     transform::TransformRewriter &rewriter, scf::IfOp ifOp,
435     transform::ApplyToEachResultList &results,
436     transform::TransformState &state) {
437   rewriter.setInsertionPoint(ifOp);
438   Region &region =
439       getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
440   if (!llvm::hasSingleElement(region)) {
441     return emitDefiniteFailure()
442            << "requires an scf.if op with a single-block "
443            << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
444   }
445   replaceOpWithRegion(rewriter, ifOp, region);
446   return DiagnosedSilenceableFailure::success();
447 }
448 
449 void transform::TakeAssumedBranchOp::getEffects(
450     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
451   onlyReadsHandle(getTargetMutable(), effects);
452   modifiesPayload(effects);
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // LoopFuseSiblingOp
457 //===----------------------------------------------------------------------===//
458 
459 /// Check if `target` and `source` are siblings, in the context that `target`
460 /// is being fused into `source`.
461 ///
462 /// This is a simple check that just checks if both operations are in the same
463 /// block and some checks to ensure that the fused IR does not violate
464 /// dominance.
465 static DiagnosedSilenceableFailure isOpSibling(Operation *target,
466                                                Operation *source) {
467   // Check if both operations are same.
468   if (target == source)
469     return emitSilenceableFailure(source)
470            << "target and source need to be different loops";
471 
472   // Check if both operations are in the same block.
473   if (target->getBlock() != source->getBlock())
474     return emitSilenceableFailure(source)
475            << "target and source are not in the same block";
476 
477   // Check if fusion will violate dominance.
478   DominanceInfo domInfo(source);
479   if (target->isBeforeInBlock(source)) {
480     // Since `target` is before `source`, all users of results of `target`
481     // need to be dominated by `source`.
482     for (Operation *user : target->getUsers()) {
483       if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
484         return emitSilenceableFailure(target)
485                << "user of results of target should be properly dominated by "
486                   "source";
487       }
488     }
489   } else {
490     // Since `target` is after `source`, all values used by `target` need
491     // to dominate `source`.
492 
493     // Check if operands of `target` are dominated by `source`.
494     for (Value operand : target->getOperands()) {
495       Operation *operandOp = operand.getDefiningOp();
496       // Operands without defining operations are block arguments. When `target`
497       // and `source` occur in the same block, these operands dominate `source`.
498       if (!operandOp)
499         continue;
500 
501       // Operand's defining operation should properly dominate `source`.
502       if (!domInfo.properlyDominates(operandOp, source,
503                                      /*enclosingOpOk=*/false))
504         return emitSilenceableFailure(target)
505                << "operands of target should be properly dominated by source";
506     }
507 
508     // Check if values used by `target` are dominated by `source`.
509     bool failed = false;
510     OpOperand *failedValue = nullptr;
511     visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
512       Operation *operandOp = operand->get().getDefiningOp();
513       if (operandOp && !domInfo.properlyDominates(operandOp, source,
514                                                   /*enclosingOpOk=*/false)) {
515         // `operand` is not an argument of an enclosing block and the defining
516         // op of `operand` is outside `target` but does not dominate `source`.
517         failed = true;
518         failedValue = operand;
519       }
520     });
521 
522     if (failed)
523       return emitSilenceableFailure(failedValue->getOwner())
524              << "values used inside regions of target should be properly "
525                 "dominated by source";
526   }
527 
528   return DiagnosedSilenceableFailure::success();
529 }
530 
531 /// Check if `target` scf.forall can be fused into `source` scf.forall.
532 ///
533 /// This simply checks if both loops have the same bounds, steps and mapping.
534 /// No attempt is made at checking that the side effects of `target` and
535 /// `source` are independent of each other.
536 static bool isForallWithIdenticalConfiguration(Operation *target,
537                                                Operation *source) {
538   auto targetOp = dyn_cast<scf::ForallOp>(target);
539   auto sourceOp = dyn_cast<scf::ForallOp>(source);
540   if (!targetOp || !sourceOp)
541     return false;
542 
543   return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
544          targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
545          targetOp.getMixedStep() == sourceOp.getMixedStep() &&
546          targetOp.getMapping() == sourceOp.getMapping();
547 }
548 
549 /// Check if `target` scf.for can be fused into `source` scf.for.
550 ///
551 /// This simply checks if both loops have the same bounds and steps. No attempt
552 /// is made at checking that the side effects of `target` and `source` are
553 /// independent of each other.
554 static bool isForWithIdenticalConfiguration(Operation *target,
555                                             Operation *source) {
556   auto targetOp = dyn_cast<scf::ForOp>(target);
557   auto sourceOp = dyn_cast<scf::ForOp>(source);
558   if (!targetOp || !sourceOp)
559     return false;
560 
561   return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
562          targetOp.getUpperBound() == sourceOp.getUpperBound() &&
563          targetOp.getStep() == sourceOp.getStep();
564 }
565 
566 DiagnosedSilenceableFailure
567 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
568                                     transform::TransformResults &results,
569                                     transform::TransformState &state) {
570   auto targetOps = state.getPayloadOps(getTarget());
571   auto sourceOps = state.getPayloadOps(getSource());
572 
573   if (!llvm::hasSingleElement(targetOps) ||
574       !llvm::hasSingleElement(sourceOps)) {
575     return emitDefiniteFailure()
576            << "requires exactly one target handle (got "
577            << llvm::range_size(targetOps) << ") and exactly one "
578            << "source handle (got " << llvm::range_size(sourceOps) << ")";
579   }
580 
581   Operation *target = *targetOps.begin();
582   Operation *source = *sourceOps.begin();
583 
584   // Check if the target and source are siblings.
585   DiagnosedSilenceableFailure diag = isOpSibling(target, source);
586   if (!diag.succeeded())
587     return diag;
588 
589   Operation *fusedLoop;
590   /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
591   if (isForWithIdenticalConfiguration(target, source)) {
592     fusedLoop = fuseIndependentSiblingForLoops(
593         cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
594   } else if (isForallWithIdenticalConfiguration(target, source)) {
595     fusedLoop = fuseIndependentSiblingForallLoops(
596         cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
597   } else
598     return emitSilenceableFailure(target->getLoc())
599            << "operations cannot be fused";
600 
601   assert(fusedLoop && "failed to fuse operations");
602 
603   results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
604   return DiagnosedSilenceableFailure::success();
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // Transform op registration
609 //===----------------------------------------------------------------------===//
610 
611 namespace {
612 class SCFTransformDialectExtension
613     : public transform::TransformDialectExtension<
614           SCFTransformDialectExtension> {
615 public:
616   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)
617 
618   using Base::Base;
619 
620   void init() {
621     declareGeneratedDialect<affine::AffineDialect>();
622     declareGeneratedDialect<func::FuncDialect>();
623 
624     registerTransformOps<
625 #define GET_OP_LIST
626 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
627         >();
628   }
629 };
630 } // namespace
631 
632 #define GET_OP_CLASSES
633 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
634 
635 void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
636   registry.addExtensions<SCFTransformDialectExtension>();
637 }
638