//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines transform dialect operations used for testing // TilingInterface // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.h.inc" using namespace mlir; using namespace mlir::transform; //===----------------------------------------------------------------------===// // TestFuseAndYieldOp //===----------------------------------------------------------------------===// static llvm::SmallDenseSet collectTiledAndFusedOps(Operation *op) { SmallVector worklist; llvm::SmallDenseSet producers; worklist.push_back(op); producers.insert(op); while (!worklist.empty()) { Operation *current = worklist.pop_back_val(); for (OpOperand &operand : current->getOpOperands()) { Operation *producer = operand.get().getDefiningOp(); if (!producer || !isa(producer) || producers.contains(producer)) continue; worklist.push_back(producer); producers.insert(producer); } } return producers; } /// Apply a tile and fuse transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, ArrayRef tileSizes, ArrayRef interchange, bool useForall, TransformResults &transformResults) { SmallVector tiledOps; SmallVector> loopOps(numLoops); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); DominanceInfo dominanceInfo(tilingInterfaceOp); llvm::SmallDenseSet tiledAndFusedOps = collectTiledAndFusedOps(tilingInterfaceOp); llvm::DenseSet yieldReplacementsFor; for (auto op : tiledAndFusedOps) { if (llvm::any_of(op->getUsers(), [&](Operation *user) { return dominanceInfo.properlyDominates(tilingInterfaceOp, user); })) { yieldReplacementsFor.insert(op); } } scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); if (useForall) { tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); } scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.setTilingOptions(tilingOptions); scf::SCFTileAndFuseOptions::ControlFnTy controlFn = [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, bool isDestinationOperand) -> std::optional { Operation *owner = originalProducer.getOwner(); bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); return scf::SCFTileAndFuseOptions::ControlFnResult{ yieldProducerReplacement}; }; tileAndFuseOptions.setFusionControlFn(controlFn); rewriter.setInsertionPoint(target); FailureOr tiledResults = scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, tileAndFuseOptions); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { for (OpResult res : toReplace->getResults()) if (auto replacement = tiledResults->replacements.lookup(res)) { Operation *replacementOp = replacement.getDefiningOp(); rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { Operation *user = use.getOwner(); return dominanceInfo.properlyDominates(replacementOp, user) && user->getParentOp() == replacementOp->getParentOp(); }); } if (toReplace->use_empty()) { rewriter.eraseOp(toReplace); } } // Report back the relevant handles to the transform op. tiledOps.push_back(tiledResults->tiledAndFusedOps.front()); assert(tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed"); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].push_back(tiledResults->loops[i]); } transformResults.set(transformOp->getOpResult(0), tiledOps); for (unsigned int i = 0; i < numLoops; ++i) transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); return success(); } DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector tileInterchange = extractFromIntegerArrayAttr(getTileInterchange()); SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); LogicalResult result = applyTileAndFuseToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, tileInterchange, getUseForall(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TestFuseConsumerOp //===----------------------------------------------------------------------===// /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. template static LogicalResult applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, uint32_t numConsumerToFuse, TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; for (Operation *target : payloadOps) { rewriter.setInsertionPoint(target); while (numConsumerToFuse--) { FailureOr fuseConsumerResults = scf::tileAndFuseConsumerOfSlice(rewriter, target); if (failed(fuseConsumerResults)) return failure(); // Report back the relevant handles to the transform op. originalConsumerOps.push_back( fuseConsumerResults->origConsumerOperand->getOwner()); fusedConsumerOps.push_back( fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); } } transformResults.set(transformOp->getOpResult(0), originalConsumerOps); transformResults.set(transformOp->getOpResult(1), fusedConsumerOps); return success(); } DiagnosedSilenceableFailure transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { LogicalResult result = applyFuseConsumer( rewriter, getOperation(), state.getPayloadOps(getTarget()), getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestFuseConsumerOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TestTileUsingForallOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTileToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, ArrayRef tileSizes, ArrayRef interchange, std::optional mapping, TransformResults &transformResults) { SmallVector tiledOps; SmallVector loopOps; for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); if (mapping) { tilingOptions.setMapping(mapping.value().getValue()); } tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); rewriter.setInsertionPoint(target); FailureOr tiledResults = scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. rewriter.replaceOp(tilingInterfaceOp, tiledResults->mergeResult.replacements); // Report back the relevant handles to the transform op. tiledOps.push_back(tiledResults->tiledOps.front()); for (Operation *loop : tiledResults->loops) loopOps.push_back(loop); } transformResults.set(transformOp->getOpResult(0), tiledOps); for (auto [index, loop] : llvm::enumerate(loopOps)) transformResults.set(transformOp->getOpResult(index + 1), {loop}); return success(); } DiagnosedSilenceableFailure transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector interchange = extractFromIntegerArrayAttr(getInterchange()); SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); LogicalResult result = applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizesOfr, interchange, getMapping(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestTileUsingForallOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // TestFuseUsingForallOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template static LogicalResult applyTilingToAll( RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, TransformResults &transformResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(1); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); if (failed(tiledResults)) return failure(); // Perform the replacement of tiled and fused values. SmallVector opsToReplace{target}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { for (OpResult res : toReplace->getResults()) if (auto replacement = tiledResults->replacements.lookup(res)) rewriter.replaceAllUsesWith(res, replacement); if (toReplace->use_empty()) rewriter.eraseOp(toReplace); } // Report back the relevant handles to the transform op. tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); assert(tiledResults->loops.size() == 1 && cast(tiledResults->loops[0]).getRank() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed"); loopOps[0] = {tiledResults->loops[0]}; } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); if (!loopOps.empty()) transformResults.set(transformOp->getOpResult(1), loopOps[0]); return success(); } DiagnosedSilenceableFailure transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { SmallVector tileSizes = extractFromIntegerArrayAttr(getTileSizes()); SmallVector tileInterchange = extractFromIntegerArrayAttr(getInterchange()); scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getRootOp()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, tileAndFuseOptions); }); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestFuseUsingForallOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getRootOpMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" namespace { class TestTilingInterfaceDialectExtension : public transform::TransformDialectExtension< TestTilingInterfaceDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTilingInterfaceDialectExtension) using Base::Base; void init() { declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "TestTilingInterfaceTransformOps.cpp.inc" >(); } }; } // namespace namespace test { void registerTestTilingInterfaceTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); } } // namespace test