1aa2a96a2SMaheshRavishankar //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// 2aa2a96a2SMaheshRavishankar // 3aa2a96a2SMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4aa2a96a2SMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information. 5aa2a96a2SMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6aa2a96a2SMaheshRavishankar // 7aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===// 8aa2a96a2SMaheshRavishankar // 9aa2a96a2SMaheshRavishankar // This file defines transform dialect operations used for testing 10aa2a96a2SMaheshRavishankar // TilingInterface 11aa2a96a2SMaheshRavishankar // 12aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===// 13aa2a96a2SMaheshRavishankar 14aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 15aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Index/IR/IndexDialect.h" 16aa2a96a2SMaheshRavishankar #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 17aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Transform/IR/TransformAttrs.h" 18aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Transform/IR/TransformDialect.h" 195a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 20aa2a96a2SMaheshRavishankar #include "mlir/Dialect/Utils/StaticValueUtils.h" 21aa2a96a2SMaheshRavishankar #include "mlir/IR/Dominance.h" 22aa2a96a2SMaheshRavishankar #include "mlir/IR/OpImplementation.h" 23aa2a96a2SMaheshRavishankar #include "mlir/Interfaces/TilingInterface.h" 24aa2a96a2SMaheshRavishankar 25aa2a96a2SMaheshRavishankar #define GET_OP_CLASSES 26aa2a96a2SMaheshRavishankar #include "TestTilingInterfaceTransformOps.h.inc" 27aa2a96a2SMaheshRavishankar 28aa2a96a2SMaheshRavishankar using namespace mlir; 29aa2a96a2SMaheshRavishankar using namespace mlir::transform; 30aa2a96a2SMaheshRavishankar 31aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===// 32aa2a96a2SMaheshRavishankar // TestFuseAndYieldOp 33aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===// 34aa2a96a2SMaheshRavishankar 35aa2a96a2SMaheshRavishankar static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) { 36aa2a96a2SMaheshRavishankar SmallVector<Operation *> worklist; 37aa2a96a2SMaheshRavishankar llvm::SmallDenseSet<Operation *> producers; 38aa2a96a2SMaheshRavishankar worklist.push_back(op); 39aa2a96a2SMaheshRavishankar producers.insert(op); 40aa2a96a2SMaheshRavishankar while (!worklist.empty()) { 41aa2a96a2SMaheshRavishankar Operation *current = worklist.pop_back_val(); 42aa2a96a2SMaheshRavishankar for (OpOperand &operand : current->getOpOperands()) { 43aa2a96a2SMaheshRavishankar Operation *producer = operand.get().getDefiningOp(); 44aa2a96a2SMaheshRavishankar if (!producer || !isa<TilingInterface>(producer) || 45aa2a96a2SMaheshRavishankar producers.contains(producer)) 46aa2a96a2SMaheshRavishankar continue; 47aa2a96a2SMaheshRavishankar worklist.push_back(producer); 48aa2a96a2SMaheshRavishankar producers.insert(producer); 49aa2a96a2SMaheshRavishankar } 50aa2a96a2SMaheshRavishankar } 51aa2a96a2SMaheshRavishankar return producers; 52aa2a96a2SMaheshRavishankar } 53aa2a96a2SMaheshRavishankar 54aa2a96a2SMaheshRavishankar /// Apply a tile and fuse transformation to all payload ops and store both the 55aa2a96a2SMaheshRavishankar /// tiled operation as well as the created tile loops. 56aa2a96a2SMaheshRavishankar template <typename Range> 57aa2a96a2SMaheshRavishankar static LogicalResult 58aa2a96a2SMaheshRavishankar applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, 59aa2a96a2SMaheshRavishankar Range &&payloadOps, unsigned numLoops, 60aa2a96a2SMaheshRavishankar ArrayRef<OpFoldResult> tileSizes, 6176ead96cSMaheshRavishankar ArrayRef<int64_t> interchange, bool useForall, 6276ead96cSMaheshRavishankar TransformResults &transformResults) { 63aa2a96a2SMaheshRavishankar SmallVector<Operation *> tiledOps; 64aa2a96a2SMaheshRavishankar SmallVector<SmallVector<Operation *>> loopOps(numLoops); 65aa2a96a2SMaheshRavishankar 66aa2a96a2SMaheshRavishankar for (Operation *target : payloadOps) { 67aa2a96a2SMaheshRavishankar auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 68aa2a96a2SMaheshRavishankar if (!tilingInterfaceOp) 69aa2a96a2SMaheshRavishankar return transformOp->emitError("only TilingInterface ops are supported"); 70aa2a96a2SMaheshRavishankar DominanceInfo dominanceInfo(tilingInterfaceOp); 71aa2a96a2SMaheshRavishankar 72aa2a96a2SMaheshRavishankar llvm::SmallDenseSet<Operation *> tiledAndFusedOps = 73aa2a96a2SMaheshRavishankar collectTiledAndFusedOps(tilingInterfaceOp); 74aa2a96a2SMaheshRavishankar llvm::DenseSet<Operation *> yieldReplacementsFor; 75aa2a96a2SMaheshRavishankar for (auto op : tiledAndFusedOps) { 76aa2a96a2SMaheshRavishankar if (llvm::any_of(op->getUsers(), [&](Operation *user) { 77aa2a96a2SMaheshRavishankar return dominanceInfo.properlyDominates(tilingInterfaceOp, user); 78aa2a96a2SMaheshRavishankar })) { 79aa2a96a2SMaheshRavishankar yieldReplacementsFor.insert(op); 80aa2a96a2SMaheshRavishankar } 81aa2a96a2SMaheshRavishankar } 82aa2a96a2SMaheshRavishankar 83aa2a96a2SMaheshRavishankar scf::SCFTilingOptions tilingOptions; 84aa2a96a2SMaheshRavishankar tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); 8576ead96cSMaheshRavishankar if (useForall) { 8676ead96cSMaheshRavishankar tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 8776ead96cSMaheshRavishankar } 88aa2a96a2SMaheshRavishankar 89aa2a96a2SMaheshRavishankar scf::SCFTileAndFuseOptions tileAndFuseOptions; 90aa2a96a2SMaheshRavishankar tileAndFuseOptions.setTilingOptions(tilingOptions); 91aa2a96a2SMaheshRavishankar 92aa2a96a2SMaheshRavishankar scf::SCFTileAndFuseOptions::ControlFnTy controlFn = 93aa2a96a2SMaheshRavishankar [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, 94d5f0969cSMaheshRavishankar bool isDestinationOperand) 95d5f0969cSMaheshRavishankar -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> { 96aa2a96a2SMaheshRavishankar Operation *owner = originalProducer.getOwner(); 97aa2a96a2SMaheshRavishankar bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); 98d5f0969cSMaheshRavishankar return scf::SCFTileAndFuseOptions::ControlFnResult{ 99d5f0969cSMaheshRavishankar yieldProducerReplacement}; 100aa2a96a2SMaheshRavishankar }; 101aa2a96a2SMaheshRavishankar tileAndFuseOptions.setFusionControlFn(controlFn); 102aa2a96a2SMaheshRavishankar 103aa2a96a2SMaheshRavishankar rewriter.setInsertionPoint(target); 104aa2a96a2SMaheshRavishankar FailureOr<scf::SCFTileAndFuseResult> tiledResults = 10576ead96cSMaheshRavishankar scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, 10676ead96cSMaheshRavishankar tileAndFuseOptions); 107aa2a96a2SMaheshRavishankar if (failed(tiledResults)) 108aa2a96a2SMaheshRavishankar return failure(); 109aa2a96a2SMaheshRavishankar 110aa2a96a2SMaheshRavishankar // Perform the replacement of tiled and fused values. 111aa2a96a2SMaheshRavishankar SmallVector<Operation *> opsToReplace{target}; 112aa2a96a2SMaheshRavishankar llvm::append_range(opsToReplace, tiledResults->fusedProducers); 113aa2a96a2SMaheshRavishankar for (Operation *toReplace : opsToReplace) { 114aa2a96a2SMaheshRavishankar for (OpResult res : toReplace->getResults()) 115aa2a96a2SMaheshRavishankar if (auto replacement = tiledResults->replacements.lookup(res)) { 116aa2a96a2SMaheshRavishankar Operation *replacementOp = replacement.getDefiningOp(); 11776ead96cSMaheshRavishankar rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { 118aa2a96a2SMaheshRavishankar Operation *user = use.getOwner(); 119aa2a96a2SMaheshRavishankar return dominanceInfo.properlyDominates(replacementOp, user) && 120aa2a96a2SMaheshRavishankar user->getParentOp() == replacementOp->getParentOp(); 121aa2a96a2SMaheshRavishankar }); 122aa2a96a2SMaheshRavishankar } 123aa2a96a2SMaheshRavishankar 124aa2a96a2SMaheshRavishankar if (toReplace->use_empty()) { 125aa2a96a2SMaheshRavishankar rewriter.eraseOp(toReplace); 126aa2a96a2SMaheshRavishankar } 127aa2a96a2SMaheshRavishankar } 128aa2a96a2SMaheshRavishankar 129aa2a96a2SMaheshRavishankar // Report back the relevant handles to the transform op. 130aa2a96a2SMaheshRavishankar tiledOps.push_back(tiledResults->tiledAndFusedOps.front()); 131aa2a96a2SMaheshRavishankar assert(tiledResults->loops.size() == numLoops && 132aa2a96a2SMaheshRavishankar "Mismatched number of loops, tile and fuse transform should have " 133aa2a96a2SMaheshRavishankar "failed"); 134aa2a96a2SMaheshRavishankar for (unsigned int i = 0; i < numLoops; ++i) 135aa2a96a2SMaheshRavishankar loopOps[i].push_back(tiledResults->loops[i]); 136aa2a96a2SMaheshRavishankar } 137aa2a96a2SMaheshRavishankar 138aa2a96a2SMaheshRavishankar transformResults.set(transformOp->getOpResult(0), tiledOps); 139aa2a96a2SMaheshRavishankar for (unsigned int i = 0; i < numLoops; ++i) 140aa2a96a2SMaheshRavishankar transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 141aa2a96a2SMaheshRavishankar 142aa2a96a2SMaheshRavishankar return success(); 143aa2a96a2SMaheshRavishankar } 144aa2a96a2SMaheshRavishankar 14576ead96cSMaheshRavishankar DiagnosedSilenceableFailure 14676ead96cSMaheshRavishankar transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, 14776ead96cSMaheshRavishankar TransformResults &transformResults, 14876ead96cSMaheshRavishankar TransformState &state) { 149aa2a96a2SMaheshRavishankar SmallVector<int64_t> tileSizes = 150aa2a96a2SMaheshRavishankar extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 151aa2a96a2SMaheshRavishankar SmallVector<int64_t> tileInterchange = 152aa2a96a2SMaheshRavishankar extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); 153aa2a96a2SMaheshRavishankar 154aa2a96a2SMaheshRavishankar SmallVector<OpFoldResult> tileSizesOfr = 155aa2a96a2SMaheshRavishankar getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 156aa2a96a2SMaheshRavishankar 157aa2a96a2SMaheshRavishankar LogicalResult result = applyTileAndFuseToAll( 158aa2a96a2SMaheshRavishankar rewriter, getOperation(), state.getPayloadOps(getTarget()), 159aa2a96a2SMaheshRavishankar tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, 16076ead96cSMaheshRavishankar tileInterchange, getUseForall(), transformResults); 161aa2a96a2SMaheshRavishankar return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 162aa2a96a2SMaheshRavishankar : DiagnosedSilenceableFailure::success(); 163aa2a96a2SMaheshRavishankar } 164aa2a96a2SMaheshRavishankar 165aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===// 1662b2ce50fSAbhishek Varma // TestFuseConsumerOp 1672b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===// 1682b2ce50fSAbhishek Varma 1692b2ce50fSAbhishek Varma /// Apply fusing of consumer transformation to all payload ops and store both 1702b2ce50fSAbhishek Varma /// the original consumer operation as well as the fused consumer operation. 1712b2ce50fSAbhishek Varma template <typename Range> 1722b2ce50fSAbhishek Varma static LogicalResult 1732b2ce50fSAbhishek Varma applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, 1749bc3102bSYun-Fly Range &&payloadOps, uint32_t numConsumerToFuse, 1759bc3102bSYun-Fly TransformResults &transformResults) { 1762b2ce50fSAbhishek Varma SmallVector<Operation *> originalConsumerOps; 1772b2ce50fSAbhishek Varma SmallVector<Operation *> fusedConsumerOps; 1782b2ce50fSAbhishek Varma 1792b2ce50fSAbhishek Varma for (Operation *target : payloadOps) { 1802b2ce50fSAbhishek Varma rewriter.setInsertionPoint(target); 1812b2ce50fSAbhishek Varma 1829bc3102bSYun-Fly while (numConsumerToFuse--) { 1832b2ce50fSAbhishek Varma FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = 1842b2ce50fSAbhishek Varma scf::tileAndFuseConsumerOfSlice(rewriter, target); 1852b2ce50fSAbhishek Varma 1862b2ce50fSAbhishek Varma if (failed(fuseConsumerResults)) 1872b2ce50fSAbhishek Varma return failure(); 1882b2ce50fSAbhishek Varma 1892b2ce50fSAbhishek Varma // Report back the relevant handles to the transform op. 1902b2ce50fSAbhishek Varma originalConsumerOps.push_back( 1912b2ce50fSAbhishek Varma fuseConsumerResults->origConsumerOperand->getOwner()); 1922b2ce50fSAbhishek Varma fusedConsumerOps.push_back( 1932b2ce50fSAbhishek Varma fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); 1942b2ce50fSAbhishek Varma } 1959bc3102bSYun-Fly } 1962b2ce50fSAbhishek Varma 1972b2ce50fSAbhishek Varma transformResults.set(transformOp->getOpResult(0), originalConsumerOps); 1982b2ce50fSAbhishek Varma transformResults.set(transformOp->getOpResult(1), fusedConsumerOps); 1992b2ce50fSAbhishek Varma return success(); 2002b2ce50fSAbhishek Varma } 2012b2ce50fSAbhishek Varma 2022b2ce50fSAbhishek Varma DiagnosedSilenceableFailure 2032b2ce50fSAbhishek Varma transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, 2042b2ce50fSAbhishek Varma TransformResults &transformResults, 2052b2ce50fSAbhishek Varma TransformState &state) { 2069bc3102bSYun-Fly LogicalResult result = applyFuseConsumer( 2079bc3102bSYun-Fly rewriter, getOperation(), state.getPayloadOps(getTarget()), 2089bc3102bSYun-Fly getNumConsumerToFuse(), transformResults); 2092b2ce50fSAbhishek Varma return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 2102b2ce50fSAbhishek Varma : DiagnosedSilenceableFailure::success(); 2112b2ce50fSAbhishek Varma } 2122b2ce50fSAbhishek Varma 2132b2ce50fSAbhishek Varma void transform::TestFuseConsumerOp::getEffects( 2142b2ce50fSAbhishek Varma SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2152c1ae801Sdonald chen consumesHandle(getTargetMutable(), effects); 2162c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 2172b2ce50fSAbhishek Varma modifiesPayload(effects); 2182b2ce50fSAbhishek Varma } 2192b2ce50fSAbhishek Varma 2202b2ce50fSAbhishek Varma //===----------------------------------------------------------------------===// 221aa2a96a2SMaheshRavishankar // TestTileUsingForallOp 222aa2a96a2SMaheshRavishankar //===----------------------------------------------------------------------===// 223aa2a96a2SMaheshRavishankar 224aa2a96a2SMaheshRavishankar /// Apply a tiling transformation to all payload ops and store both the 225aa2a96a2SMaheshRavishankar /// tiled operation as well as the created tile loops. 226aa2a96a2SMaheshRavishankar template <typename Range> 227aa2a96a2SMaheshRavishankar static LogicalResult 228aa2a96a2SMaheshRavishankar applyTileToAll(RewriterBase &rewriter, Operation *transformOp, 229aa2a96a2SMaheshRavishankar Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes, 230aa2a96a2SMaheshRavishankar ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping, 23176ead96cSMaheshRavishankar TransformResults &transformResults) { 232aa2a96a2SMaheshRavishankar SmallVector<Operation *> tiledOps; 233aa2a96a2SMaheshRavishankar SmallVector<Operation *> loopOps; 234aa2a96a2SMaheshRavishankar 235aa2a96a2SMaheshRavishankar for (Operation *target : payloadOps) { 236aa2a96a2SMaheshRavishankar auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 237aa2a96a2SMaheshRavishankar if (!tilingInterfaceOp) 238aa2a96a2SMaheshRavishankar return transformOp->emitError("only TilingInterface ops are supported"); 239aa2a96a2SMaheshRavishankar scf::SCFTilingOptions tilingOptions; 240aa2a96a2SMaheshRavishankar tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); 241aa2a96a2SMaheshRavishankar if (mapping) { 2426740d701SMaheshRavishankar tilingOptions.setMapping(mapping.value().getValue()); 243aa2a96a2SMaheshRavishankar } 24476ead96cSMaheshRavishankar tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 245aa2a96a2SMaheshRavishankar 246aa2a96a2SMaheshRavishankar rewriter.setInsertionPoint(target); 247aa2a96a2SMaheshRavishankar FailureOr<scf::SCFTilingResult> tiledResults = 24876ead96cSMaheshRavishankar scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); 249aa2a96a2SMaheshRavishankar if (failed(tiledResults)) 250aa2a96a2SMaheshRavishankar return failure(); 251aa2a96a2SMaheshRavishankar 252aa2a96a2SMaheshRavishankar // Perform the replacement of tiled and fused values. 253*4b563458SKunwar Grover rewriter.replaceOp(tilingInterfaceOp, 254*4b563458SKunwar Grover tiledResults->mergeResult.replacements); 255aa2a96a2SMaheshRavishankar 256aa2a96a2SMaheshRavishankar // Report back the relevant handles to the transform op. 257aa2a96a2SMaheshRavishankar tiledOps.push_back(tiledResults->tiledOps.front()); 258aa2a96a2SMaheshRavishankar for (Operation *loop : tiledResults->loops) 259aa2a96a2SMaheshRavishankar loopOps.push_back(loop); 260aa2a96a2SMaheshRavishankar } 261aa2a96a2SMaheshRavishankar 262aa2a96a2SMaheshRavishankar transformResults.set(transformOp->getOpResult(0), tiledOps); 263aa2a96a2SMaheshRavishankar for (auto [index, loop] : llvm::enumerate(loopOps)) 264aa2a96a2SMaheshRavishankar transformResults.set(transformOp->getOpResult(index + 1), {loop}); 265aa2a96a2SMaheshRavishankar 266aa2a96a2SMaheshRavishankar return success(); 267aa2a96a2SMaheshRavishankar } 268aa2a96a2SMaheshRavishankar 26976ead96cSMaheshRavishankar DiagnosedSilenceableFailure 27076ead96cSMaheshRavishankar transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, 27176ead96cSMaheshRavishankar TransformResults &transformResults, 27276ead96cSMaheshRavishankar TransformState &state) { 273aa2a96a2SMaheshRavishankar SmallVector<int64_t> tileSizes = 274aa2a96a2SMaheshRavishankar extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 275aa2a96a2SMaheshRavishankar SmallVector<int64_t> interchange = 276aa2a96a2SMaheshRavishankar extractFromIntegerArrayAttr<int64_t>(getInterchange()); 277aa2a96a2SMaheshRavishankar SmallVector<OpFoldResult> tileSizesOfr = 278aa2a96a2SMaheshRavishankar getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 279aa2a96a2SMaheshRavishankar 280aa2a96a2SMaheshRavishankar LogicalResult result = 281aa2a96a2SMaheshRavishankar applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), 282aa2a96a2SMaheshRavishankar tileSizesOfr, interchange, getMapping(), transformResults); 283aa2a96a2SMaheshRavishankar return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 284aa2a96a2SMaheshRavishankar : DiagnosedSilenceableFailure::success(); 285aa2a96a2SMaheshRavishankar } 286aa2a96a2SMaheshRavishankar 287aa2a96a2SMaheshRavishankar void transform::TestTileUsingForallOp::getEffects( 288aa2a96a2SMaheshRavishankar SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2892c1ae801Sdonald chen consumesHandle(getTargetMutable(), effects); 2902c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 291aa2a96a2SMaheshRavishankar modifiesPayload(effects); 292aa2a96a2SMaheshRavishankar } 293aa2a96a2SMaheshRavishankar 29476ead96cSMaheshRavishankar //===----------------------------------------------------------------------===// 29576ead96cSMaheshRavishankar // TestFuseUsingForallOp 29676ead96cSMaheshRavishankar //===----------------------------------------------------------------------===// 29776ead96cSMaheshRavishankar 29876ead96cSMaheshRavishankar /// Apply a tiling transformation to all payload ops and store both the 29976ead96cSMaheshRavishankar /// tiled operation as well as the created tile loops. 30076ead96cSMaheshRavishankar template <typename Range> 30176ead96cSMaheshRavishankar static LogicalResult applyTilingToAll( 30276ead96cSMaheshRavishankar RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, 30376ead96cSMaheshRavishankar unsigned numLoops, TransformResults &transformResults, 30476ead96cSMaheshRavishankar function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> 30576ead96cSMaheshRavishankar applyFn) { 30676ead96cSMaheshRavishankar SmallVector<Operation *> tiledLinalgOps; 30776ead96cSMaheshRavishankar SmallVector<SmallVector<Operation *>> loopOps(1); 30876ead96cSMaheshRavishankar 30976ead96cSMaheshRavishankar for (Operation *target : payloadOps) { 31076ead96cSMaheshRavishankar auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 31176ead96cSMaheshRavishankar if (!tilingInterfaceOp) 31276ead96cSMaheshRavishankar return transformOp->emitError("only TilingInterface ops are supported"); 31376ead96cSMaheshRavishankar 31476ead96cSMaheshRavishankar rewriter.setInsertionPoint(target); 31576ead96cSMaheshRavishankar FailureOr<scf::SCFTileAndFuseResult> tiledResults = 31676ead96cSMaheshRavishankar applyFn(tilingInterfaceOp); 31776ead96cSMaheshRavishankar if (failed(tiledResults)) 31876ead96cSMaheshRavishankar return failure(); 31976ead96cSMaheshRavishankar 32076ead96cSMaheshRavishankar // Perform the replacement of tiled and fused values. 32176ead96cSMaheshRavishankar SmallVector<Operation *> opsToReplace{target}; 32276ead96cSMaheshRavishankar llvm::append_range(opsToReplace, tiledResults->fusedProducers); 32376ead96cSMaheshRavishankar for (Operation *toReplace : opsToReplace) { 32476ead96cSMaheshRavishankar for (OpResult res : toReplace->getResults()) 32576ead96cSMaheshRavishankar if (auto replacement = tiledResults->replacements.lookup(res)) 32676ead96cSMaheshRavishankar rewriter.replaceAllUsesWith(res, replacement); 32776ead96cSMaheshRavishankar if (toReplace->use_empty()) 32876ead96cSMaheshRavishankar rewriter.eraseOp(toReplace); 32976ead96cSMaheshRavishankar } 33076ead96cSMaheshRavishankar 33176ead96cSMaheshRavishankar // Report back the relevant handles to the transform op. 33276ead96cSMaheshRavishankar tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); 33376ead96cSMaheshRavishankar assert(tiledResults->loops.size() == 1 && 33476ead96cSMaheshRavishankar cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops && 33576ead96cSMaheshRavishankar "Mismatched number of loops, tile and fuse transform should have " 33676ead96cSMaheshRavishankar "failed"); 33776ead96cSMaheshRavishankar loopOps[0] = {tiledResults->loops[0]}; 33876ead96cSMaheshRavishankar } 33976ead96cSMaheshRavishankar 34076ead96cSMaheshRavishankar transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 34176ead96cSMaheshRavishankar if (!loopOps.empty()) 34276ead96cSMaheshRavishankar transformResults.set(transformOp->getOpResult(1), loopOps[0]); 34376ead96cSMaheshRavishankar 34476ead96cSMaheshRavishankar return success(); 34576ead96cSMaheshRavishankar } 34676ead96cSMaheshRavishankar 34776ead96cSMaheshRavishankar DiagnosedSilenceableFailure 34876ead96cSMaheshRavishankar transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, 34976ead96cSMaheshRavishankar TransformResults &transformResults, 35076ead96cSMaheshRavishankar TransformState &state) { 35176ead96cSMaheshRavishankar SmallVector<int64_t> tileSizes = 35276ead96cSMaheshRavishankar extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 35376ead96cSMaheshRavishankar SmallVector<int64_t> tileInterchange = 35476ead96cSMaheshRavishankar extractFromIntegerArrayAttr<int64_t>(getInterchange()); 35576ead96cSMaheshRavishankar 35676ead96cSMaheshRavishankar scf::SCFTilingOptions tilingOptions; 35776ead96cSMaheshRavishankar tilingOptions.interchangeVector = tileInterchange; 35876ead96cSMaheshRavishankar SmallVector<OpFoldResult> tileSizesOfr = 35976ead96cSMaheshRavishankar getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 36076ead96cSMaheshRavishankar tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); 36176ead96cSMaheshRavishankar tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 36276ead96cSMaheshRavishankar scf::SCFTileAndFuseOptions tileAndFuseOptions; 36376ead96cSMaheshRavishankar tileAndFuseOptions.tilingOptions = tilingOptions; 36476ead96cSMaheshRavishankar LogicalResult result = applyTilingToAll( 36576ead96cSMaheshRavishankar rewriter, getOperation(), state.getPayloadOps(getRootOp()), 36676ead96cSMaheshRavishankar tileSizes.size() - llvm::count(tileSizes, 0), transformResults, 36776ead96cSMaheshRavishankar [&](TilingInterface tilingInterfaceOp) 36876ead96cSMaheshRavishankar -> FailureOr<scf::SCFTileAndFuseResult> { 36976ead96cSMaheshRavishankar return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, 37076ead96cSMaheshRavishankar tileAndFuseOptions); 37176ead96cSMaheshRavishankar }); 37276ead96cSMaheshRavishankar return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 37376ead96cSMaheshRavishankar : DiagnosedSilenceableFailure::success(); 37476ead96cSMaheshRavishankar } 37576ead96cSMaheshRavishankar 37676ead96cSMaheshRavishankar void transform::TestFuseUsingForallOp::getEffects( 37776ead96cSMaheshRavishankar SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3782c1ae801Sdonald chen consumesHandle(getRootOpMutable(), effects); 3792c1ae801Sdonald chen producesHandle(getOperation()->getOpResults(), effects); 38076ead96cSMaheshRavishankar modifiesPayload(effects); 38176ead96cSMaheshRavishankar } 38276ead96cSMaheshRavishankar 383aa2a96a2SMaheshRavishankar #define GET_OP_CLASSES 384aa2a96a2SMaheshRavishankar #include "TestTilingInterfaceTransformOps.cpp.inc" 385aa2a96a2SMaheshRavishankar 386aa2a96a2SMaheshRavishankar namespace { 387aa2a96a2SMaheshRavishankar class TestTilingInterfaceDialectExtension 388aa2a96a2SMaheshRavishankar : public transform::TransformDialectExtension< 389aa2a96a2SMaheshRavishankar TestTilingInterfaceDialectExtension> { 390aa2a96a2SMaheshRavishankar public: 39184cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 39284cc1865SNikhil Kalra TestTilingInterfaceDialectExtension) 39384cc1865SNikhil Kalra 394aa2a96a2SMaheshRavishankar using Base::Base; 395aa2a96a2SMaheshRavishankar 396aa2a96a2SMaheshRavishankar void init() { 397aa2a96a2SMaheshRavishankar declareDependentDialect<affine::AffineDialect>(); 398aa2a96a2SMaheshRavishankar declareDependentDialect<index::IndexDialect>(); 399aa2a96a2SMaheshRavishankar declareDependentDialect<scf::SCFDialect>(); 400aa2a96a2SMaheshRavishankar declareDependentDialect<tensor::TensorDialect>(); 401aa2a96a2SMaheshRavishankar 402aa2a96a2SMaheshRavishankar registerTransformOps< 403aa2a96a2SMaheshRavishankar #define GET_OP_LIST 404aa2a96a2SMaheshRavishankar #include "TestTilingInterfaceTransformOps.cpp.inc" 405aa2a96a2SMaheshRavishankar >(); 406aa2a96a2SMaheshRavishankar } 407aa2a96a2SMaheshRavishankar }; 408aa2a96a2SMaheshRavishankar } // namespace 409aa2a96a2SMaheshRavishankar 410aa2a96a2SMaheshRavishankar namespace test { 411aa2a96a2SMaheshRavishankar void registerTestTilingInterfaceTransformDialectExtension( 412aa2a96a2SMaheshRavishankar DialectRegistry ®istry) { 413aa2a96a2SMaheshRavishankar registry.addExtensions<TestTilingInterfaceDialectExtension>(); 414aa2a96a2SMaheshRavishankar } 415aa2a96a2SMaheshRavishankar } // namespace test 416