1 //===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Tensor transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" 18 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 19 #include "mlir/Dialect/Transform/IR/TransformOps.h" 20 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 using namespace mlir; 25 26 namespace { 27 struct TestTensorTransforms 28 : public PassWrapper<TestTensorTransforms, OperationPass<>> { 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms) 30 31 TestTensorTransforms() = default; 32 TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} 33 34 void getDependentDialects(DialectRegistry ®istry) const override { 35 registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect, 36 transform::TransformDialect>(); 37 } 38 39 StringRef getArgument() const final { 40 return "test-tensor-transform-patterns"; 41 } 42 StringRef getDescription() const final { 43 return "Test Tensor transformation patterns by applying them greedily."; 44 } 45 46 void runOnOperation() override; 47 48 Option<bool> testFoldConstantExtractSlice{ 49 *this, "test-fold-constant-extract-slice", 50 llvm::cl::desc("Test folding arith.constant and 
tensor.extract_slice"), 51 llvm::cl::init(false)}; 52 53 Option<bool> testFoldConsecutiveInsertExtractSlice{ 54 *this, "test-fold-consecutive-insert-extract-slice", 55 llvm::cl::desc( 56 "Test folding consecutive tensor.insert_slice/tensor.extract_slice"), 57 llvm::cl::init(false)}; 58 59 Option<bool> testRewriteExtractSliceWithTiledCollapseShape{ 60 *this, "test-rewrite-extract-slice-from-collapse-shape", 61 llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " 62 "with loop nest"), 63 llvm::cl::init(false)}; 64 65 Option<bool> testDropRedundantInsertSliceRankExpansion{ 66 *this, "test-drop-redundant-insert-slice-rank-expansion", 67 llvm::cl::desc("Test dropping redundant insert_slice rank expansions"), 68 llvm::cl::init(false)}; 69 70 Option<bool> testReassociativeReshapeFolding{ 71 *this, "test-reassociative-reshape-folding", 72 llvm::cl::desc("Test folding of expand_shape/collapse_shape"), 73 llvm::cl::init(false)}; 74 75 Option<bool> testBubbleUpExpandShapePatterns{ 76 *this, "test-expand-shape-bubbling", 77 llvm::cl::desc("Test folding of expand_shape/collapse_shape"), 78 llvm::cl::init(false)}; 79 80 Option<bool> testFoldIntoPackAndUnpack{ 81 *this, "test-fold-into-pack-and-unpack", 82 llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"), 83 llvm::cl::init(false)}; 84 85 Option<bool> useForeach{ 86 *this, "use-foreach", 87 llvm::cl::desc( 88 "Use the scf.forall operation when generating loop nests for " 89 "the extract_slice of collapse_shape pattern"), 90 llvm::cl::init(false)}; 91 92 Option<bool> testSimplifyPackUnpackPatterns{ 93 *this, "test-simplify-pack-unpack-patterns", 94 llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"), 95 llvm::cl::init(false)}; 96 97 Option<bool> testTrackingListener{ 98 *this, "test-tracking-listener", 99 llvm::cl::desc("Test tensor TrackingListener for the transform dialect"), 100 llvm::cl::init(false)}; 101 }; 102 } // namespace 103 104 static void 
applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { 105 RewritePatternSet patterns(rootOp->getContext()); 106 tensor::populateReassociativeReshapeFoldingPatterns(patterns); 107 (void)applyPatternsGreedily(rootOp, std::move(patterns)); 108 } 109 110 static void applyBubbleUpExpandShapePatterns(Operation *rootOp) { 111 RewritePatternSet patterns(rootOp->getContext()); 112 tensor::populateBubbleUpExpandShapePatterns(patterns); 113 (void)applyPatternsGreedily(rootOp, std::move(patterns)); 114 } 115 116 static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { 117 RewritePatternSet patterns(rootOp->getContext()); 118 tensor::populateFoldIntoPackAndUnpackPatterns(patterns); 119 (void)applyPatternsGreedily(rootOp, std::move(patterns)); 120 } 121 122 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) { 123 RewritePatternSet patterns(rootOp->getContext()); 124 tensor::ControlConstantExtractSliceFusionFn controlFn = 125 [](tensor::ExtractSliceOp op) { 126 if (!op.getSource().hasOneUse()) 127 return false; 128 129 auto resultType = cast<ShapedType>(op.getResult().getType()); 130 constexpr int64_t kConstantFoldingMaxNumElements = 1024; 131 return resultType.getNumElements() <= kConstantFoldingMaxNumElements; 132 }; 133 134 tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); 135 (void)applyPatternsGreedily(rootOp, std::move(patterns)); 136 } 137 138 static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) { 139 RewritePatternSet patterns(rootOp->getContext()); 140 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); 141 (void)applyPatternsGreedily(rootOp, std::move(patterns)); 142 } 143 144 static void 145 applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { 146 RewritePatternSet patterns(rootOp->getContext()); 147 tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); 148 (void)applyPatternsGreedily(rootOp, std::move(patterns)); 149 } 150 151 
/// Greedily apply the patterns that simplify tensor.pack/tensor.unpack ops.
static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
  (void)applyPatternsGreedily(rootOp, std::move(patterns));
}

namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
/// stitches together the desired tile from slices of the source of the collapse
/// shape op. Subclasses choose the concrete loop construct (scf.for vs
/// scf.forall) by implementing `emitReplacement`.
struct RewriteExtractSliceFromCollapseShapeBase
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
      : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}

  /// Emit a loop or gather operation that uses `helper` to take each point in
  /// the parallel iteration space bounds, extract a slice from the source
  /// tensor and insert it into `dest`. For examples, see below for `scf.for`
  /// and `scf.foreach`.
  virtual LogicalResult
  emitReplacement(tensor::ExtractSliceOp op, Value dest,
                  tensor::ExtractSliceFromCollapseHelper &helper,
                  PatternRewriter &rewriter) const = 0;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when the slice's source is produced by a collapse_shape.
    auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return rewriter.notifyMatchFailure(
          op, "producer is not a tensor.collapse_shape op");

    // Try to simplify the collapse shape using a rank-reducing slice, if
    // possible.
    FailureOr<Operation *> simplifiedCollapseShapeResult =
        tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
                                                                  rewriter);
    if (succeeded(simplifiedCollapseShapeResult)) {
      auto newCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
      // The collapse shape op might have been simplified away, so we can just
      // return.
      if (!newCollapseOp)
        return success();
      collapseOp = newCollapseOp;
    }

    // Materialize the output shape values of the slice operation.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
      return rewriter.notifyMatchFailure(op, "failed to reify result shapes");

    // Create the destination tensor using the above values.
    Type elementType = op.getSourceType().getElementType();
    SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
    Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
                                                  elementType);

    // Calculate the parameters for the tile loop nest.
    FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
        tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
                                                       op);
    if (failed(params))
      return rewriter.notifyMatchFailure(
          op, "could not calculate tiling parameters");
    // Delegate loop construction to the concrete subclass.
    return emitReplacement(op, dest, *params, rewriter);
  }
};

/// Variant that stitches the tile together with a nest of sequential scf.for
/// loops carrying the destination tensor as an iter_arg.
struct RewriteExtractSliceFromCollapseShapeUsingScfFor
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // One loop per tiled dimension, each with bounds [0, size) and step 1.
    const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value> lbs(numTiledDims, zero);
    SmallVector<Value> steps(numTiledDims, one);

    scf::LoopNest nest = scf::buildLoopNest(
        rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
            ValueRange iterArgs) -> scf::ValueVector {
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);

          // Insert the slice into the destination.
          return {nestedBuilder.create<tensor::InsertSliceOp>(
              loc, tile, iterArgs[0], insertParams)};
        });
    rewriter.replaceOp(op, nest.results);

    return success();
  }
};

/// Variant that stitches the tile together with a parallel scf.forall op,
/// inserting each slice via tensor.parallel_insert_slice in the terminator.
struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto forallOp = rewriter.create<scf::ForallOp>(
        loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
        /*outputs=*/dest,
        /*mapping=*/std::nullopt,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
          // The region args are the thread ids followed by the output
          // tensor block arguments; split them accordingly.
          unsigned numThreadIdRegionArgs =
              helper.getIterationSpaceSizes().size();
          unsigned numOutputRegionArgs =
              regionArgs.size() - numThreadIdRegionArgs;
          ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
          ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
          assert(outputArgs.size() == 1 &&
                 "there should only be one output region argument");
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
          // Insert the slice into the destination.
          auto term = nestedBuilder.create<scf::InParallelOp>(loc);
          nestedBuilder.setInsertionPointToStart(term.getBody());
          nestedBuilder.create<tensor::ParallelInsertSliceOp>(
              loc, tile, outputArgs[0], insertParams);
        });
    rewriter.replaceOp(op, forallOp->getResult(0));
    return success();
  }
};
} // namespace

/// Greedily apply one of the two collapse_shape -> extract_slice rewrites,
/// selected by `useForeach`. Unlike the other helpers, the driver result is
/// propagated so the pass can fail on non-convergence.
static LogicalResult
applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
                                             bool useForeach) {
  RewritePatternSet patterns(rootOp->getContext());
  if (useForeach)
    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
        rootOp->getContext());
  else
    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
        rootOp->getContext());
  return applyPatternsGreedily(rootOp, std::move(patterns));
}

namespace {
/// TrackingListener subclass that exposes the protected `findReplacementOp`
/// hook so the test below can query it directly.
class DummyTrackingListener : public transform::TrackingListener {
public:
  using transform::TrackingListener::TrackingListener;

  // Expose `findReplacementOp` as a public function, so that it can be tested.
  Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
    Operation *replacementOp;
    if (!findReplacementOp(replacementOp, op, newValues).succeeded())
      return nullptr;
    return replacementOp;
  }
};
} // namespace

/// Drive the TrackingListener test: locate the single op annotated with the
/// unit attribute "replaced", collect one replacement value per result from
/// ops annotated "replacement_<i>" (whose IntegerAttr value selects the
/// result index), and ask a DummyTrackingListener for the replacement op,
/// emitting a remark on success or an error on failure.
static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
  // Find replaced op.
  Operation *replaced = nullptr;
  WalkResult status = rootOp->walk([&](Operation *op) {
    if (op->hasAttr("replaced")) {
      // Exactly one op may carry the "replaced" marker.
      if (replaced) {
        op->emitError("only one 'replaced' op is allowed per test case");
        replaced->emitRemark("other 'replaced' op");
        return WalkResult::interrupt();
      }
      replaced = op;
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();
  if (!replaced) {
    rootOp->emitError("could not find 'replaced' op");
    return failure();
  }

  // Find replacements.
  // One slot per result of the replaced op; each must be filled exactly once.
  SmallVector<Value> replacements(replaced->getNumResults(), Value());
  status = rootOp->walk([&](Operation *op) {
    for (int64_t i = 0; i < replaced->getNumResults(); ++i) {
      if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" +
                                                     std::to_string(i))) {
        if (replacements[i]) {
          op->emitError("only one 'replacement_" + std::to_string(i) +
                        "' is allowed per test case");
          replacements[i].getDefiningOp()->emitRemark("other 'replacement_" +
                                                      std::to_string(i) + "'");
          return WalkResult::interrupt();
        }
        // The attribute value selects which result of the annotated op acts
        // as the replacement for result #i of the replaced op.
        replacements[i] = op->getResult(attr.getInt());
      }
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();

  // Every result of the replaced op needs a replacement value.
  if (!llvm::all_of(replacements,
                    [](Value v) { return static_cast<bool>(v); })) {
    replaced->emitError("insufficient replacement values");
    return failure();
  }

  // Find the replacement op (if any) and emit a remark/error.
  // A minimal transform state and a throwaway named sequence are enough to
  // construct the listener; neither is otherwise used by the test.
  transform::TransformState transformState =
      transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
                                                      /*payloadRoot=*/nullptr);
  MLIRContext *context = rootOp->getContext();
  OpBuilder builder(context);
  OwningOpRef<transform::NamedSequenceOp> transformOp =
      builder.create<transform::NamedSequenceOp>(
          rootOp->getLoc(),
          /*sym_name=*/"test_sequence",
          /*function_type=*/
          TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
          /*sym_visibility*/ StringAttr::get(context, "public"),
          /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
          /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
  DummyTrackingListener listener(transformState, transformOp.get());
  Operation *replacement = listener.getReplacementOp(replaced, replacements);
  if (!replacement) {
    replaced->emitError("listener could not find replacement op");
    return failure();
  }

  replacement->emitRemark("replacement found");
  return success();
}

void TestTensorTransforms::runOnOperation() { 389 Operation *rootOp = getOperation(); 390 if (testSimplifyPackUnpackPatterns) 391 applySimplifyPackUnpackPatterns(rootOp); 392 if (testFoldConstantExtractSlice) 393 applyFoldConstantExtractSlicePatterns(rootOp); 394 if (testFoldConsecutiveInsertExtractSlice) 395 applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); 396 if (testDropRedundantInsertSliceRankExpansion) 397 applyDropRedundantInsertSliceRankExpansionPatterns(rootOp); 398 if (testReassociativeReshapeFolding) 399 applyReassociativeReshapeFoldingPatterns(rootOp); 400 if (testBubbleUpExpandShapePatterns) 401 applyBubbleUpExpandShapePatterns(rootOp); 402 if (testFoldIntoPackAndUnpack) 403 applyFoldIntoPackAndUnpackPatterns(rootOp); 404 if (testRewriteExtractSliceWithTiledCollapseShape) { 405 if (failed( 406 applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) 407 return signalPassFailure(); 408 } 409 if (testTrackingListener) 410 if (failed(testTrackingListenerReplacements(rootOp))) 411 return signalPassFailure(); 412 } 413 414 namespace mlir { 415 namespace test { 416 void registerTestTensorTransforms() { 417 PassRegistration<TestTensorTransforms>(); 418 } 419 } // namespace test 420 } // namespace mlir 421