1 //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines transform dialect operations used for testing 10 // TilingInterface 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Index/IR/IndexDialect.h" 16 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 17 #include "mlir/Dialect/Transform/IR/TransformAttrs.h" 18 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/IR/Dominance.h" 22 #include "mlir/IR/OpImplementation.h" 23 #include "mlir/Interfaces/TilingInterface.h" 24 25 #define GET_OP_CLASSES 26 #include "TestTilingInterfaceTransformOps.h.inc" 27 28 using namespace mlir; 29 using namespace mlir::transform; 30 31 //===----------------------------------------------------------------------===// 32 // TestFuseAndYieldOp 33 //===----------------------------------------------------------------------===// 34 35 static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) { 36 SmallVector<Operation *> worklist; 37 llvm::SmallDenseSet<Operation *> producers; 38 worklist.push_back(op); 39 producers.insert(op); 40 while (!worklist.empty()) { 41 Operation *current = worklist.pop_back_val(); 42 for (OpOperand &operand : current->getOpOperands()) { 43 Operation *producer = operand.get().getDefiningOp(); 44 if (!producer || !isa<TilingInterface>(producer) || 45 producers.contains(producer)) 46 continue; 47 worklist.push_back(producer); 48 producers.insert(producer); 49 } 50 } 51 return producers; 52 } 53 54 /// Apply a tile and fuse transformation to all payload ops and store both the 55 /// tiled operation as well as the created tile loops. 56 template <typename Range> 57 static LogicalResult 58 applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, 59 Range &&payloadOps, unsigned numLoops, 60 ArrayRef<OpFoldResult> tileSizes, 61 ArrayRef<int64_t> interchange, bool useForall, 62 TransformResults &transformResults) { 63 SmallVector<Operation *> tiledOps; 64 SmallVector<SmallVector<Operation *>> loopOps(numLoops); 65 66 for (Operation *target : payloadOps) { 67 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 68 if (!tilingInterfaceOp) 69 return transformOp->emitError("only TilingInterface ops are supported"); 70 DominanceInfo dominanceInfo(tilingInterfaceOp); 71 72 llvm::SmallDenseSet<Operation *> tiledAndFusedOps = 73 collectTiledAndFusedOps(tilingInterfaceOp); 74 llvm::DenseSet<Operation *> yieldReplacementsFor; 75 for (auto op : tiledAndFusedOps) { 76 if (llvm::any_of(op->getUsers(), [&](Operation *user) { 77 return dominanceInfo.properlyDominates(tilingInterfaceOp, user); 78 })) { 79 yieldReplacementsFor.insert(op); 80 } 81 } 82 83 scf::SCFTilingOptions tilingOptions; 84 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); 85 if (useForall) { 86 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 87 } 88 89 scf::SCFTileAndFuseOptions tileAndFuseOptions; 90 tileAndFuseOptions.setTilingOptions(tilingOptions); 91 92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn = 93 [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, 94 bool isDestinationOperand) 95 -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> { 96 Operation *owner = originalProducer.getOwner(); 97 bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); 98 return scf::SCFTileAndFuseOptions::ControlFnResult{ 99 yieldProducerReplacement}; 100 }; 101 tileAndFuseOptions.setFusionControlFn(controlFn); 102 103 rewriter.setInsertionPoint(target); 104 FailureOr<scf::SCFTileAndFuseResult> tiledResults = 105 scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, 106 tileAndFuseOptions); 107 if (failed(tiledResults)) 108 return failure(); 109 110 // Perform the replacement of tiled and fused values. 111 SmallVector<Operation *> opsToReplace{target}; 112 llvm::append_range(opsToReplace, tiledResults->fusedProducers); 113 for (Operation *toReplace : opsToReplace) { 114 for (OpResult res : toReplace->getResults()) 115 if (auto replacement = tiledResults->replacements.lookup(res)) { 116 Operation *replacementOp = replacement.getDefiningOp(); 117 rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { 118 Operation *user = use.getOwner(); 119 return dominanceInfo.properlyDominates(replacementOp, user) && 120 user->getParentOp() == replacementOp->getParentOp(); 121 }); 122 } 123 124 if (toReplace->use_empty()) { 125 rewriter.eraseOp(toReplace); 126 } 127 } 128 129 // Report back the relevant handles to the transform op. 130 tiledOps.push_back(tiledResults->tiledAndFusedOps.front()); 131 assert(tiledResults->loops.size() == numLoops && 132 "Mismatched number of loops, tile and fuse transform should have " 133 "failed"); 134 for (unsigned int i = 0; i < numLoops; ++i) 135 loopOps[i].push_back(tiledResults->loops[i]); 136 } 137 138 transformResults.set(transformOp->getOpResult(0), tiledOps); 139 for (unsigned int i = 0; i < numLoops; ++i) 140 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); 141 142 return success(); 143 } 144 145 DiagnosedSilenceableFailure 146 transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, 147 TransformResults &transformResults, 148 TransformState &state) { 149 SmallVector<int64_t> tileSizes = 150 extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 151 SmallVector<int64_t> tileInterchange = 152 extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); 153 154 SmallVector<OpFoldResult> tileSizesOfr = 155 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 156 157 LogicalResult result = applyTileAndFuseToAll( 158 rewriter, getOperation(), state.getPayloadOps(getTarget()), 159 tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, 160 tileInterchange, getUseForall(), transformResults); 161 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 162 : DiagnosedSilenceableFailure::success(); 163 } 164 165 //===----------------------------------------------------------------------===// 166 // TestFuseConsumerOp 167 //===----------------------------------------------------------------------===// 168 169 /// Apply fusing of consumer transformation to all payload ops and store both 170 /// the original consumer operation as well as the fused consumer operation. 171 template <typename Range> 172 static LogicalResult 173 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, 174 Range &&payloadOps, uint32_t numConsumerToFuse, 175 TransformResults &transformResults) { 176 SmallVector<Operation *> originalConsumerOps; 177 SmallVector<Operation *> fusedConsumerOps; 178 179 for (Operation *target : payloadOps) { 180 rewriter.setInsertionPoint(target); 181 182 while (numConsumerToFuse--) { 183 FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = 184 scf::tileAndFuseConsumerOfSlice(rewriter, target); 185 186 if (failed(fuseConsumerResults)) 187 return failure(); 188 189 // Report back the relevant handles to the transform op. 190 originalConsumerOps.push_back( 191 fuseConsumerResults->origConsumerOperand->getOwner()); 192 fusedConsumerOps.push_back( 193 fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); 194 } 195 } 196 197 transformResults.set(transformOp->getOpResult(0), originalConsumerOps); 198 transformResults.set(transformOp->getOpResult(1), fusedConsumerOps); 199 return success(); 200 } 201 202 DiagnosedSilenceableFailure 203 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, 204 TransformResults &transformResults, 205 TransformState &state) { 206 LogicalResult result = applyFuseConsumer( 207 rewriter, getOperation(), state.getPayloadOps(getTarget()), 208 getNumConsumerToFuse(), transformResults); 209 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 210 : DiagnosedSilenceableFailure::success(); 211 } 212 213 void transform::TestFuseConsumerOp::getEffects( 214 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 215 consumesHandle(getTargetMutable(), effects); 216 producesHandle(getOperation()->getOpResults(), effects); 217 modifiesPayload(effects); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // TestTileUsingForallOp 222 //===----------------------------------------------------------------------===// 223 224 /// Apply a tiling transformation to all payload ops and store both the 225 /// tiled operation as well as the created tile loops. 226 template <typename Range> 227 static LogicalResult 228 applyTileToAll(RewriterBase &rewriter, Operation *transformOp, 229 Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes, 230 ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping, 231 TransformResults &transformResults) { 232 SmallVector<Operation *> tiledOps; 233 SmallVector<Operation *> loopOps; 234 235 for (Operation *target : payloadOps) { 236 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 237 if (!tilingInterfaceOp) 238 return transformOp->emitError("only TilingInterface ops are supported"); 239 scf::SCFTilingOptions tilingOptions; 240 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); 241 if (mapping) { 242 tilingOptions.setMapping(mapping.value().getValue()); 243 } 244 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 245 246 rewriter.setInsertionPoint(target); 247 FailureOr<scf::SCFTilingResult> tiledResults = 248 scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); 249 if (failed(tiledResults)) 250 return failure(); 251 252 // Perform the replacement of tiled and fused values. 253 rewriter.replaceOp(tilingInterfaceOp, 254 tiledResults->mergeResult.replacements); 255 256 // Report back the relevant handles to the transform op. 257 tiledOps.push_back(tiledResults->tiledOps.front()); 258 for (Operation *loop : tiledResults->loops) 259 loopOps.push_back(loop); 260 } 261 262 transformResults.set(transformOp->getOpResult(0), tiledOps); 263 for (auto [index, loop] : llvm::enumerate(loopOps)) 264 transformResults.set(transformOp->getOpResult(index + 1), {loop}); 265 266 return success(); 267 } 268 269 DiagnosedSilenceableFailure 270 transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, 271 TransformResults &transformResults, 272 TransformState &state) { 273 SmallVector<int64_t> tileSizes = 274 extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 275 SmallVector<int64_t> interchange = 276 extractFromIntegerArrayAttr<int64_t>(getInterchange()); 277 SmallVector<OpFoldResult> tileSizesOfr = 278 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 279 280 LogicalResult result = 281 applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), 282 tileSizesOfr, interchange, getMapping(), transformResults); 283 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 284 : DiagnosedSilenceableFailure::success(); 285 } 286 287 void transform::TestTileUsingForallOp::getEffects( 288 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 289 consumesHandle(getTargetMutable(), effects); 290 producesHandle(getOperation()->getOpResults(), effects); 291 modifiesPayload(effects); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // TestFuseUsingForallOp 296 //===----------------------------------------------------------------------===// 297 298 /// Apply a tiling transformation to all payload ops and store both the 299 /// tiled operation as well as the created tile loops. 300 template <typename Range> 301 static LogicalResult applyTilingToAll( 302 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, 303 unsigned numLoops, TransformResults &transformResults, 304 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> 305 applyFn) { 306 SmallVector<Operation *> tiledLinalgOps; 307 SmallVector<SmallVector<Operation *>> loopOps(1); 308 309 for (Operation *target : payloadOps) { 310 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); 311 if (!tilingInterfaceOp) 312 return transformOp->emitError("only TilingInterface ops are supported"); 313 314 rewriter.setInsertionPoint(target); 315 FailureOr<scf::SCFTileAndFuseResult> tiledResults = 316 applyFn(tilingInterfaceOp); 317 if (failed(tiledResults)) 318 return failure(); 319 320 // Perform the replacement of tiled and fused values. 321 SmallVector<Operation *> opsToReplace{target}; 322 llvm::append_range(opsToReplace, tiledResults->fusedProducers); 323 for (Operation *toReplace : opsToReplace) { 324 for (OpResult res : toReplace->getResults()) 325 if (auto replacement = tiledResults->replacements.lookup(res)) 326 rewriter.replaceAllUsesWith(res, replacement); 327 if (toReplace->use_empty()) 328 rewriter.eraseOp(toReplace); 329 } 330 331 // Report back the relevant handles to the transform op. 332 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); 333 assert(tiledResults->loops.size() == 1 && 334 cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops && 335 "Mismatched number of loops, tile and fuse transform should have " 336 "failed"); 337 loopOps[0] = {tiledResults->loops[0]}; 338 } 339 340 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); 341 if (!loopOps.empty()) 342 transformResults.set(transformOp->getOpResult(1), loopOps[0]); 343 344 return success(); 345 } 346 347 DiagnosedSilenceableFailure 348 transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, 349 TransformResults &transformResults, 350 TransformState &state) { 351 SmallVector<int64_t> tileSizes = 352 extractFromIntegerArrayAttr<int64_t>(getTileSizes()); 353 SmallVector<int64_t> tileInterchange = 354 extractFromIntegerArrayAttr<int64_t>(getInterchange()); 355 356 scf::SCFTilingOptions tilingOptions; 357 tilingOptions.interchangeVector = tileInterchange; 358 SmallVector<OpFoldResult> tileSizesOfr = 359 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); 360 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); 361 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); 362 scf::SCFTileAndFuseOptions tileAndFuseOptions; 363 tileAndFuseOptions.tilingOptions = tilingOptions; 364 LogicalResult result = applyTilingToAll( 365 rewriter, getOperation(), state.getPayloadOps(getRootOp()), 366 tileSizes.size() - llvm::count(tileSizes, 0), transformResults, 367 [&](TilingInterface tilingInterfaceOp) 368 -> FailureOr<scf::SCFTileAndFuseResult> { 369 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, 370 tileAndFuseOptions); 371 }); 372 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() 373 : DiagnosedSilenceableFailure::success(); 374 } 375 376 void transform::TestFuseUsingForallOp::getEffects( 377 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 378 consumesHandle(getRootOpMutable(), effects); 379 producesHandle(getOperation()->getOpResults(), effects); 380 modifiesPayload(effects); 381 } 382 383 #define GET_OP_CLASSES 384 #include "TestTilingInterfaceTransformOps.cpp.inc" 385 386 namespace { 387 class TestTilingInterfaceDialectExtension 388 : public transform::TransformDialectExtension< 389 TestTilingInterfaceDialectExtension> { 390 public: 391 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 392 TestTilingInterfaceDialectExtension) 393 394 using Base::Base; 395 396 void init() { 397 declareDependentDialect<affine::AffineDialect>(); 398 declareDependentDialect<index::IndexDialect>(); 399 declareDependentDialect<scf::SCFDialect>(); 400 declareDependentDialect<tensor::TensorDialect>(); 401 402 registerTransformOps< 403 #define GET_OP_LIST 404 #include "TestTilingInterfaceTransformOps.cpp.inc" 405 >(); 406 } 407 }; 408 } // namespace 409 410 namespace test { 411 void registerTestTilingInterfaceTransformDialectExtension( 412 DialectRegistry ®istry) { 413 registry.addExtensions<TestTilingInterfaceDialectExtension>(); 414 } 415 } // namespace test 416