1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <optional> 10 #include <type_traits> 11 12 #include "mlir/Analysis/SliceAnalysis.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 23 #include "mlir/Dialect/SCF/IR/SCF.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/IR/VectorOps.h" 26 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 27 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 29 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 30 #include "mlir/Pass/Pass.h" 31 #include "mlir/Pass/PassManager.h" 32 #include "mlir/Support/LLVM.h" 33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34 35 using namespace mlir; 36 using namespace mlir::linalg; 37 using namespace mlir::vector; 38 39 namespace { 40 41 struct TestVectorToVectorLowering 42 : public PassWrapper<TestVectorToVectorLowering, 43 OperationPass<func::FuncOp>> { 44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 45 46 TestVectorToVectorLowering() = default; 47 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 48 : PassWrapper(pass) {} 49 StringRef getArgument() const final { 50 return "test-vector-to-vector-lowering"; 51 } 52 StringRef getDescription() const final { 53 return "Test lowering patterns between ops in the vector dialect"; 54 } 55 56 void getDependentDialects(DialectRegistry ®istry) const override { 57 registry.insert<affine::AffineDialect>(); 58 registry.insert<vector::VectorDialect>(); 59 } 60 61 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 62 llvm::cl::init(false)}; 63 64 void runOnOperation() override { 65 auto *ctx = &getContext(); 66 RewritePatternSet patterns(ctx); 67 if (unroll) { 68 populateVectorUnrollPatterns( 69 patterns, 70 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 71 filter)); 72 } 73 populateVectorToVectorCanonicalizationPatterns(patterns); 74 populateBubbleVectorBitCastOpPatterns(patterns); 75 populateCastAwayVectorLeadingOneDimPatterns(patterns); 76 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 77 } 78 79 private: 80 // Return the target shape based on op type. 81 static std::optional<SmallVector<int64_t>> getShape(Operation *op) { 82 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 83 return SmallVector<int64_t>(2, 2); 84 if (isa<vector::ContractionOp>(op)) 85 return SmallVector<int64_t>(3, 2); 86 // For transfer ops, just propagate the shape coming from 87 // InsertStridedSlices/ExtractStridedSlices. 88 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 89 VectorType dstVec; 90 for (Operation *users : readOp->getUsers()) { 91 auto extract = dyn_cast<ExtractStridedSliceOp>(users); 92 if (!extract) 93 return std::nullopt; 94 auto vecType = cast<VectorType>(extract.getResult().getType()); 95 if (dstVec && dstVec != vecType) 96 return std::nullopt; 97 dstVec = vecType; 98 } 99 return SmallVector<int64_t>(dstVec.getShape().begin(), 100 dstVec.getShape().end()); 101 } 102 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 103 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 104 if (!insert) 105 return std::nullopt; 106 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 107 return SmallVector<int64_t>(shape.begin(), shape.end()); 108 } 109 return std::nullopt; 110 } 111 112 static LogicalResult filter(Operation *op) { 113 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 114 ContractionOp, TransferReadOp, TransferWriteOp>(op)); 115 } 116 }; 117 118 struct TestVectorContractionPrepareForMMTLowering 119 : public PassWrapper<TestVectorContractionPrepareForMMTLowering, 120 OperationPass<func::FuncOp>> { 121 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 122 TestVectorContractionPrepareForMMTLowering) 123 124 StringRef getArgument() const final { 125 return "test-vector-contraction-prepare-for-mmt-lowering"; 126 } 127 StringRef getDescription() const final { 128 return "Test vector.contraction matmul canonicalization for MMT lowering."; 129 } 130 TestVectorContractionPrepareForMMTLowering() = default; 131 132 void getDependentDialects(DialectRegistry ®istry) const override { 133 registry.insert<affine::AffineDialect, arith::ArithDialect, 134 vector::VectorDialect>(); 135 } 136 137 void runOnOperation() override { 138 MLIRContext *ctx = &getContext(); 139 RewritePatternSet patterns(ctx); 140 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); 141 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 142 } 143 }; 144 145 struct TestVectorUnrollingPatterns 146 : public PassWrapper<TestVectorUnrollingPatterns, 147 OperationPass<func::FuncOp>> { 148 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 149 150 StringRef getArgument() const final { 151 return "test-vector-unrolling-patterns"; 152 } 153 StringRef getDescription() const final { 154 return "Test lowering patterns to unroll contract ops in the vector " 155 "dialect"; 156 } 157 TestVectorUnrollingPatterns() = default; 158 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 159 : PassWrapper(pass) {} 160 void runOnOperation() override { 161 MLIRContext *ctx = &getContext(); 162 RewritePatternSet patterns(ctx); 163 populateVectorUnrollPatterns( 164 patterns, UnrollVectorOptions() 165 .setNativeShape(ArrayRef<int64_t>{2, 2}) 166 .setFilterConstraint([](Operation *op) { 167 return success(isa<arith::AddFOp, vector::FMAOp, 168 vector::MultiDimReductionOp>(op)); 169 })); 170 populateVectorUnrollPatterns( 171 patterns, UnrollVectorOptions() 172 .setNativeShape(ArrayRef<int64_t>{2}) 173 .setFilterConstraint([](Operation *op) { 174 return success(isa<vector::ReductionOp>(op)); 175 })); 176 populateVectorUnrollPatterns( 177 patterns, UnrollVectorOptions() 178 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 179 .setFilterConstraint([](Operation *op) { 180 return success(isa<vector::TransposeOp>(op)); 181 })); 182 183 if (unrollBasedOnType) { 184 UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 185 [](Operation *op) -> std::optional<SmallVector<int64_t>> { 186 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 187 SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(), 188 4); 189 Type lhsType = contractOp.getLhsType().getElementType(); 190 nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; 191 return nativeShape; 192 }; 193 194 UnrollVectorOptions opts; 195 opts.setNativeShapeFn(nativeShapeFn) 196 .setFilterConstraint( 197 [](Operation *op) { return success(isa<ContractionOp>(op)); }); 198 199 if (!unrollOrder.empty()) { 200 opts.setUnrollTraversalOrderFn( 201 [this](Operation *op) -> std::optional<SmallVector<int64_t>> { 202 vector::ContractionOp contractOp = 203 cast<vector::ContractionOp>(op); 204 if (contractOp.getIteratorTypes().size() == unrollOrder.size()) 205 return SmallVector<int64_t>(unrollOrder.begin(), 206 unrollOrder.end()); 207 return std::nullopt; 208 }); 209 } 210 populateVectorUnrollPatterns(patterns, opts); 211 } else { 212 auto nativeShapeFn = 213 [](Operation *op) -> std::optional<SmallVector<int64_t>> { 214 auto contractOp = dyn_cast<ContractionOp>(op); 215 if (!contractOp) 216 return std::nullopt; 217 return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2); 218 }; 219 populateVectorUnrollPatterns(patterns, 220 UnrollVectorOptions() 221 .setNativeShapeFn(nativeShapeFn) 222 .setFilterConstraint([](Operation *op) { 223 return success(isa<ContractionOp>(op)); 224 })); 225 } 226 populateVectorToVectorCanonicalizationPatterns(patterns); 227 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 228 } 229 230 ListOption<int64_t> unrollOrder{*this, "unroll-order", 231 llvm::cl::desc("set the unroll order")}; 232 233 Option<bool> unrollBasedOnType{ 234 *this, "unroll-based-on-type", 235 llvm::cl::desc("Set the unroll factor based on type of the operation"), 236 llvm::cl::init(false)}; 237 }; 238 239 struct TestVectorTransferUnrollingPatterns 240 : public PassWrapper<TestVectorTransferUnrollingPatterns, 241 OperationPass<func::FuncOp>> { 242 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 243 TestVectorTransferUnrollingPatterns) 244 245 TestVectorTransferUnrollingPatterns() = default; 246 TestVectorTransferUnrollingPatterns( 247 const TestVectorTransferUnrollingPatterns &pass) 248 : PassWrapper(pass) {} 249 250 void getDependentDialects(DialectRegistry ®istry) const override { 251 registry.insert<affine::AffineDialect>(); 252 } 253 StringRef getArgument() const final { 254 return "test-vector-transfer-unrolling-patterns"; 255 } 256 StringRef getDescription() const final { 257 return "Test lowering patterns to unroll transfer ops in the vector " 258 "dialect"; 259 } 260 void runOnOperation() override { 261 MLIRContext *ctx = &getContext(); 262 RewritePatternSet patterns(ctx); 263 UnrollVectorOptions opts; 264 opts.setNativeShape(ArrayRef<int64_t>{2, 2}) 265 .setFilterConstraint([](Operation *op) { 266 return success(isa<vector::TransferReadOp, vector::TransferWriteOp, 267 vector::GatherOp>(op)); 268 }); 269 if (reverseUnrollOrder.getValue()) { 270 opts.setUnrollTraversalOrderFn( 271 [](Operation *op) -> std::optional<SmallVector<int64_t>> { 272 int64_t numLoops = 0; 273 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) 274 numLoops = readOp.getVectorType().getRank(); 275 else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) 276 numLoops = writeOp.getVectorType().getRank(); 277 else if (auto gatherOp = dyn_cast<vector::GatherOp>(op)) 278 numLoops = gatherOp.getVectorType().getRank(); 279 else 280 return std::nullopt; 281 auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops)); 282 return llvm::to_vector(order); 283 }); 284 } 285 populateVectorUnrollPatterns(patterns, opts); 286 populateVectorToVectorCanonicalizationPatterns(patterns); 287 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 288 } 289 290 Option<bool> reverseUnrollOrder{ 291 *this, "reverse-unroll-order", 292 llvm::cl::desc( 293 "reverse the order of unrolling of vector transfer operations"), 294 llvm::cl::init(false)}; 295 }; 296 297 struct TestScalarVectorTransferLoweringPatterns 298 : public PassWrapper<TestScalarVectorTransferLoweringPatterns, 299 OperationPass<func::FuncOp>> { 300 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 301 TestScalarVectorTransferLoweringPatterns) 302 303 TestScalarVectorTransferLoweringPatterns() = default; 304 TestScalarVectorTransferLoweringPatterns( 305 const TestScalarVectorTransferLoweringPatterns &pass) 306 : PassWrapper(pass) {} 307 308 StringRef getArgument() const final { 309 return "test-scalar-vector-transfer-lowering"; 310 } 311 StringRef getDescription() const final { 312 return "Test lowering of scalar vector transfers to memref loads/stores."; 313 } 314 315 void getDependentDialects(DialectRegistry ®istry) const override { 316 registry.insert<affine::AffineDialect, memref::MemRefDialect, 317 tensor::TensorDialect, vector::VectorDialect>(); 318 } 319 320 Option<bool> allowMultipleUses{ 321 *this, "allow-multiple-uses", 322 llvm::cl::desc("Fold transfer operations with multiple uses"), 323 llvm::cl::init(false)}; 324 325 void runOnOperation() override { 326 MLIRContext *ctx = &getContext(); 327 RewritePatternSet patterns(ctx); 328 vector::populateScalarVectorTransferLoweringPatterns( 329 patterns, /*benefit=*/1, allowMultipleUses.getValue()); 330 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 331 } 332 }; 333 334 struct TestVectorTransferOpt 335 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { 336 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 337 338 StringRef getArgument() const final { return "test-vector-transferop-opt"; } 339 StringRef getDescription() const final { 340 return "Test optimization transformations for transfer ops"; 341 } 342 void runOnOperation() override { 343 IRRewriter rewriter(&getContext()); 344 transferOpflowOpt(rewriter, getOperation()); 345 } 346 }; 347 348 struct TestVectorTransferCollapseInnerMostContiguousDims 349 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 350 OperationPass<func::FuncOp>> { 351 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 352 TestVectorTransferCollapseInnerMostContiguousDims) 353 354 TestVectorTransferCollapseInnerMostContiguousDims() = default; 355 TestVectorTransferCollapseInnerMostContiguousDims( 356 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 357 358 void getDependentDialects(DialectRegistry ®istry) const override { 359 registry.insert<memref::MemRefDialect, affine::AffineDialect>(); 360 } 361 362 StringRef getArgument() const final { 363 return "test-vector-transfer-collapse-inner-most-dims"; 364 } 365 366 StringRef getDescription() const final { 367 return "Test lowering patterns that reducedes the rank of the vector " 368 "transfer memory and vector operands."; 369 } 370 371 void runOnOperation() override { 372 RewritePatternSet patterns(&getContext()); 373 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 374 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 375 } 376 }; 377 378 struct TestSinkVectorBroadcast 379 : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> { 380 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast) 381 382 TestSinkVectorBroadcast() = default; 383 TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default; 384 385 void getDependentDialects(DialectRegistry ®istry) const override { 386 registry.insert<memref::MemRefDialect, affine::AffineDialect>(); 387 } 388 389 StringRef getArgument() const final { return "test-sink-vector-broadcast"; } 390 391 StringRef getDescription() const final { 392 return "Test lowering patterns that eliminate redundant brodacast " 393 "operations."; 394 } 395 396 void runOnOperation() override { 397 RewritePatternSet patterns(&getContext()); 398 populateSinkVectorBroadcastPatterns(patterns); 399 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 400 } 401 }; 402 403 struct TestVectorReduceToContractPatternsPatterns 404 : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 405 OperationPass<func::FuncOp>> { 406 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 407 TestVectorReduceToContractPatternsPatterns) 408 409 StringRef getArgument() const final { 410 return "test-vector-reduction-to-contract-patterns"; 411 } 412 StringRef getDescription() const final { 413 return "Test patterns to convert multireduce op to contract and combine " 414 "broadcast/transpose to contract"; 415 } 416 void runOnOperation() override { 417 RewritePatternSet patterns(&getContext()); 418 populateVectorReductionToContractPatterns(patterns); 419 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 420 } 421 }; 422 423 struct TestVectorChainedReductionFoldingPatterns 424 : public PassWrapper<TestVectorChainedReductionFoldingPatterns, 425 OperationPass<func::FuncOp>> { 426 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 427 TestVectorChainedReductionFoldingPatterns) 428 429 StringRef getArgument() const final { 430 return "test-vector-chained-reduction-folding-patterns"; 431 } 432 StringRef getDescription() const final { 433 return "Test patterns to fold chained vector reductions"; 434 } 435 void runOnOperation() override { 436 RewritePatternSet patterns(&getContext()); 437 populateChainedVectorReductionFoldingPatterns(patterns); 438 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 439 } 440 }; 441 442 struct TestVectorBreakDownReductionPatterns 443 : public PassWrapper<TestVectorBreakDownReductionPatterns, 444 OperationPass<func::FuncOp>> { 445 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 446 TestVectorBreakDownReductionPatterns) 447 448 StringRef getArgument() const final { 449 return "test-vector-break-down-reduction-patterns"; 450 } 451 StringRef getDescription() const final { 452 return "Test patterns to break down vector reductions into arith " 453 "reductions"; 454 } 455 void runOnOperation() override { 456 RewritePatternSet patterns(&getContext()); 457 populateBreakDownVectorReductionPatterns(patterns, 458 /*maxNumElementsToExtract=*/2); 459 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 460 } 461 }; 462 463 struct TestFlattenVectorTransferPatterns 464 : public PassWrapper<TestFlattenVectorTransferPatterns, 465 OperationPass<func::FuncOp>> { 466 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 467 TestFlattenVectorTransferPatterns) 468 469 StringRef getArgument() const final { 470 return "test-vector-transfer-flatten-patterns"; 471 } 472 StringRef getDescription() const final { 473 return "Test patterns to rewrite contiguous row-major N-dimensional " 474 "vector.transfer_{read,write} ops into 1D transfers"; 475 } 476 void getDependentDialects(DialectRegistry ®istry) const override { 477 registry.insert<memref::MemRefDialect>(); 478 registry.insert<affine::AffineDialect>(); 479 registry.insert<vector::VectorDialect>(); 480 } 481 void runOnOperation() override { 482 RewritePatternSet patterns(&getContext()); 483 populateFlattenVectorTransferPatterns(patterns); 484 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 485 } 486 }; 487 488 struct TestVectorScanLowering 489 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { 490 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 491 492 StringRef getArgument() const final { return "test-vector-scan-lowering"; } 493 StringRef getDescription() const final { 494 return "Test lowering patterns that lower the scan op in the vector " 495 "dialect"; 496 } 497 void runOnOperation() override { 498 RewritePatternSet patterns(&getContext()); 499 populateVectorScanLoweringPatterns(patterns); 500 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 501 } 502 }; 503 504 /// Allocate shared memory for a single warp to test lowering of 505 /// WarpExecuteOnLane0Op. 506 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, 507 WarpExecuteOnLane0Op warpOp, 508 Type type) { 509 static constexpr int64_t kSharedMemorySpace = 3; 510 // Compute type of shared memory buffer. 511 MemRefType memrefType; 512 if (auto vectorType = dyn_cast<VectorType>(type)) { 513 memrefType = 514 MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 515 kSharedMemorySpace); 516 } else { 517 memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); 518 } 519 520 // Get symbol table holding all shared memory globals. 521 ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); 522 SymbolTable symbolTable(moduleOp); 523 524 // Create a pretty name. 525 SmallString<64> buf; 526 llvm::raw_svector_ostream os(buf); 527 interleave(memrefType.getShape(), os, "x"); 528 os << "x" << memrefType.getElementType(); 529 std::string symbolName = (Twine("__shared_") + os.str()).str(); 530 531 auto ip = builder.saveInsertionPoint(); 532 builder.setInsertionPoint(moduleOp); 533 auto global = builder.create<memref::GlobalOp>( 534 loc, 535 /*sym_name=*/symbolName, 536 /*sym_visibility=*/builder.getStringAttr("private"), 537 /*type=*/memrefType, 538 /*initial_value=*/Attribute(), 539 /*constant=*/false, 540 /*alignment=*/IntegerAttr()); 541 symbolTable.insert(global); 542 // The symbol table inserts at the end of the module, but globals are a bit 543 // nicer if they are at the beginning. 544 global->moveBefore(&moduleOp.front()); 545 546 builder.restoreInsertionPoint(ip); 547 return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); 548 } 549 550 static Value warpReduction(Location loc, OpBuilder &builder, Value input, 551 CombiningKind kind, uint32_t size) { 552 // First reduce on a single thread to get per lane reduction value. 553 Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); 554 // Parallel reduction using butterfly shuffles. 555 for (uint64_t i = 1; i < size; i <<= 1) { 556 Value shuffled = builder 557 .create<gpu::ShuffleOp>(loc, laneVal, i, 558 /*width=*/size, 559 /*mode=*/gpu::ShuffleMode::XOR) 560 .getShuffleResult(); 561 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); 562 } 563 return laneVal; 564 } 565 566 struct TestVectorDistribution 567 : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { 568 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) 569 570 void getDependentDialects(DialectRegistry ®istry) const override { 571 registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect, 572 affine::AffineDialect>(); 573 } 574 575 StringRef getArgument() const final { return "test-vector-warp-distribute"; } 576 StringRef getDescription() const final { 577 return "Test vector warp distribute transformation and lowering patterns"; 578 } 579 TestVectorDistribution() = default; 580 TestVectorDistribution(const TestVectorDistribution &pass) 581 : PassWrapper(pass) {} 582 583 Option<bool> warpOpToSCF{ 584 *this, "rewrite-warp-ops-to-scf-if", 585 llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), 586 llvm::cl::init(false)}; 587 588 Option<bool> distributeTransferWriteOps{ 589 *this, "distribute-transfer-write", 590 llvm::cl::desc("Test distribution of transfer write"), 591 llvm::cl::init(false)}; 592 593 Option<unsigned> maxTransferWriteElements{ 594 *this, "max-transfer-write-elements", 595 llvm::cl::desc("Maximum number of transfer write elements to distribute"), 596 llvm::cl::init(1)}; 597 598 Option<bool> hoistUniform{*this, "hoist-uniform", 599 llvm::cl::desc("Test hoist uniform"), 600 llvm::cl::init(false)}; 601 602 Option<bool> propagateDistribution{ 603 *this, "propagate-distribution", 604 llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)}; 605 606 void runOnOperation() override { 607 RewritePatternSet patterns(&getContext()); 608 609 getOperation().walk([&](Operation *op) { 610 if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) { 611 if (hoistUniform) { 612 moveScalarUniformCode(warpOp); 613 } 614 WalkResult::interrupt(); 615 } 616 }); 617 MLIRContext *ctx = &getContext(); 618 auto distributionFn = [](Value val) { 619 // Create a map (d0, d1) -> (d1) to distribute along the inner 620 // dimension. Once we support n-d distribution we can add more 621 // complex cases. 622 VectorType vecType = dyn_cast<VectorType>(val.getType()); 623 int64_t vecRank = vecType ? vecType.getRank() : 0; 624 OpBuilder builder(val.getContext()); 625 if (vecRank == 0) 626 return AffineMap::get(val.getContext()); 627 return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); 628 }; 629 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, 630 Value srcIdx, int64_t warpSz) { 631 assert((val.getType().isF32() || val.getType().isInteger(32)) && 632 "unsupported shuffle type"); 633 Type i32Type = builder.getIntegerType(32); 634 Value srcIdxI32 = 635 builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx); 636 Value warpSzI32 = builder.create<arith::ConstantOp>( 637 loc, builder.getIntegerAttr(i32Type, warpSz)); 638 Value result = builder 639 .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32, 640 gpu::ShuffleMode::IDX) 641 .getResult(0); 642 return result; 643 }; 644 if (distributeTransferWriteOps && propagateDistribution) { 645 RewritePatternSet patterns(ctx); 646 vector::populatePropagateWarpVectorDistributionPatterns( 647 patterns, distributionFn, shuffleFn, /*benefit=*/1, 648 /*readBenefit=*/0); 649 vector::populateDistributeReduction(patterns, warpReduction, 1); 650 populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); 651 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 652 } else if (distributeTransferWriteOps) { 653 RewritePatternSet patterns(ctx); 654 populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 655 maxTransferWriteElements); 656 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 657 } else if (propagateDistribution) { 658 RewritePatternSet patterns(ctx); 659 vector::populatePropagateWarpVectorDistributionPatterns( 660 patterns, distributionFn, shuffleFn); 661 vector::populateDistributeReduction(patterns, warpReduction); 662 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 663 } 664 WarpExecuteOnLane0LoweringOptions options; 665 options.warpAllocationFn = allocateGlobalSharedMemory; 666 options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, 667 WarpExecuteOnLane0Op warpOp) { 668 builder.create<gpu::BarrierOp>(loc); 669 }; 670 // Test on one pattern in isolation. 671 if (warpOpToSCF) { 672 populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); 673 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 674 return; 675 } 676 } 677 }; 678 679 struct TestVectorExtractStridedSliceLowering 680 : public PassWrapper<TestVectorExtractStridedSliceLowering, 681 OperationPass<func::FuncOp>> { 682 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 683 TestVectorExtractStridedSliceLowering) 684 685 StringRef getArgument() const final { 686 return "test-vector-extract-strided-slice-lowering"; 687 } 688 StringRef getDescription() const final { 689 return "Test lowering patterns that converts vector.extract_strided_slice " 690 "into a chain of vector.extract and vector.insert ops"; 691 } 692 void runOnOperation() override { 693 RewritePatternSet patterns(&getContext()); 694 populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns); 695 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 696 } 697 }; 698 699 struct TestVectorBreakDownBitCast 700 : public PassWrapper<TestVectorBreakDownBitCast, 701 OperationPass<func::FuncOp>> { 702 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast) 703 704 StringRef getArgument() const final { 705 return "test-vector-break-down-bitcast"; 706 } 707 StringRef getDescription() const final { 708 return "Test pattern that breaks down vector.bitcast ops "; 709 } 710 void runOnOperation() override { 711 RewritePatternSet patterns(&getContext()); 712 populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) { 713 return op.getSourceVectorType().getShape().back() > 4; 714 }); 715 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 716 } 717 }; 718 719 struct TestCreateVectorBroadcast 720 : public PassWrapper<TestCreateVectorBroadcast, 721 OperationPass<func::FuncOp>> { 722 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast) 723 724 StringRef getArgument() const final { return "test-create-vector-broadcast"; } 725 StringRef getDescription() const final { 726 return "Test optimization transformations for transfer ops"; 727 } 728 void getDependentDialects(DialectRegistry ®istry) const override { 729 registry.insert<vector::VectorDialect>(); 730 } 731 732 void runOnOperation() override { 733 getOperation()->walk([](Operation *op) { 734 if (op->getName().getStringRef() != "test_create_broadcast") 735 return; 736 auto targetShape = 737 cast<VectorType>(op->getResult(0).getType()).getShape(); 738 auto arrayAttr = 739 cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims")) 740 .asArrayRef(); 741 llvm::SetVector<int64_t> broadcastedDims; 742 broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end()); 743 OpBuilder b(op); 744 Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp( 745 b, op->getOperand(0), targetShape, broadcastedDims); 746 op->getResult(0).replaceAllUsesWith(bcast); 747 op->erase(); 748 }); 749 } 750 }; 751 752 struct TestVectorGatherLowering 753 : public PassWrapper<TestVectorGatherLowering, 754 OperationPass<func::FuncOp>> { 755 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering) 756 757 StringRef getArgument() const final { return "test-vector-gather-lowering"; } 758 StringRef getDescription() const final { 759 return "Test patterns that lower the gather op in the vector conditional " 760 "loads"; 761 } 762 void getDependentDialects(DialectRegistry ®istry) const override { 763 registry.insert<arith::ArithDialect, func::FuncDialect, 764 memref::MemRefDialect, scf::SCFDialect, 765 tensor::TensorDialect, vector::VectorDialect>(); 766 } 767 768 void runOnOperation() override { 769 RewritePatternSet patterns(&getContext()); 770 populateVectorGatherLoweringPatterns(patterns); 771 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 772 } 773 }; 774 775 struct TestFoldArithExtensionIntoVectorContractPatterns 776 : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, 777 OperationPass<func::FuncOp>> { 778 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 779 TestFoldArithExtensionIntoVectorContractPatterns) 780 781 StringRef getArgument() const final { 782 return "test-fold-arith-extf-into-vector-contract-patterns"; 783 } 784 StringRef getDescription() const final { 785 return "Test patterns that fold arithmetic extension ops into vector " 786 "contract ops"; 787 } 788 789 void getDependentDialects(DialectRegistry ®istry) const override { 790 registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect, 791 memref::MemRefDialect, scf::SCFDialect, 792 tensor::TensorDialect, vector::VectorDialect>(); 793 } 794 795 void runOnOperation() override { 796 RewritePatternSet patterns(&getContext()); 797 populateFoldArithExtensionPatterns(patterns); 798 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 799 } 800 }; 801 802 struct TestVectorEmulateMaskedLoadStore final 803 : public PassWrapper<TestVectorEmulateMaskedLoadStore, 804 OperationPass<func::FuncOp>> { 805 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore) 806 807 StringRef getArgument() const override { 808 return "test-vector-emulate-masked-load-store"; 809 } 810 StringRef getDescription() const override { 811 return "Test patterns that emulate the maskedload/maskedstore op by " 812 " memref.load/store and scf.if"; 813 } 814 void getDependentDialects(DialectRegistry ®istry) const override { 815 registry 816 .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect, 817 scf::SCFDialect, vector::VectorDialect>(); 818 } 819 820 void runOnOperation() override { 821 RewritePatternSet patterns(&getContext()); 822 populateVectorMaskedLoadStoreEmulationPatterns(patterns); 823 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 824 } 825 }; 826 827 struct TestVectorLinearize final 828 : public PassWrapper<TestVectorLinearize, OperationPass<>> { 829 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) 830 831 StringRef getArgument() const override { return "test-vector-linearize"; } 832 StringRef getDescription() const override { 833 return "Linearizes ND vectors for N >= 2 into 1D vectors"; 834 } 835 void getDependentDialects(DialectRegistry ®istry) const override { 836 registry.insert<vector::VectorDialect>(); 837 } 838 839 void runOnOperation() override { 840 auto *context = &getContext(); 841 842 TypeConverter typeConverter; 843 RewritePatternSet patterns(context); 844 ConversionTarget target(*context); 845 846 vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter, 847 patterns, target); 848 if (failed(applyPartialConversion(getOperation(), target, 849 std::move(patterns)))) 850 return signalPassFailure(); 851 } 852 }; 853 } // namespace 854 855 namespace mlir { 856 namespace test { 857 void registerTestVectorLowerings() { 858 PassRegistration<TestVectorToVectorLowering>(); 859 860 PassRegistration<TestVectorContractionPrepareForMMTLowering>(); 861 862 PassRegistration<TestVectorUnrollingPatterns>(); 863 864 PassRegistration<TestVectorTransferUnrollingPatterns>(); 865 866 PassRegistration<TestScalarVectorTransferLoweringPatterns>(); 867 868 PassRegistration<TestVectorTransferOpt>(); 869 870 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 871 872 PassRegistration<TestSinkVectorBroadcast>(); 873 874 PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 875 876 PassRegistration<TestVectorChainedReductionFoldingPatterns>(); 877 878 PassRegistration<TestVectorBreakDownReductionPatterns>(); 879 880 PassRegistration<TestFlattenVectorTransferPatterns>(); 881 882 PassRegistration<TestVectorScanLowering>(); 883 884 PassRegistration<TestVectorDistribution>(); 885 886 PassRegistration<TestVectorExtractStridedSliceLowering>(); 887 888 PassRegistration<TestVectorBreakDownBitCast>(); 889 890 PassRegistration<TestCreateVectorBroadcast>(); 891 892 PassRegistration<TestVectorGatherLowering>(); 893 894 PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); 895 896 PassRegistration<TestVectorEmulateMaskedLoadStore>(); 897 898 PassRegistration<TestVectorLinearize>(); 899 } 900 } // namespace test 901 } // namespace mlir 902