//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (dstVec && dstVec != vecType)
          return std::nullopt;
        dstVec = vecType;
      }
      return SmallVector<int64_t>(dstVec.getShape());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return std::nullopt;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t>(shape);
    }
    return std::nullopt;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};
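
// Illustrative example (not part of the original file): with the 2x2 native
// shape that getShape returns for elementwise ops, an op such as
//   %0 = arith.addf %a, %b : vector<4x4xf32>
// is expected to be unrolled into four vector<2x2xf32> arith.addf ops whose
// results are reassembled with vector.insert_strided_slice.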

struct TestVectorContractionPrepareForMMTLowering
    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorContractionPrepareForMMTLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-prepare-for-mmt-lowering";
  }
  StringRef getDescription() const final {
    return "Test vector.contraction matmul canonicalization for MMT lowering.";
  }
  TestVectorContractionPrepareForMMTLowering() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    vector::VectorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn(
            [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
              vector::ContractionOp contractOp =
                  cast<vector::ContractionOp>(op);
              if (contractOp.getIteratorTypes().size() == unrollOrder.size())
                return SmallVector<int64_t>(unrollOrder.begin(),
                                            unrollOrder.end());
              return std::nullopt;
            });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return std::nullopt;
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};
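
// Illustrative usage (flag syntax is an assumption, not from the original
// file): something like
//   mlir-opt -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1" in.mlir
// would unroll 3-D vector.contract ops visiting dimension 2 first.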

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
                             vector::GatherOp>(op));
        });
    if (reverseUnrollOrder.getValue()) {
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
              numLoops = gatherOp.getVectorType().getRank();
            else
              return std::nullopt;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};
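
// Illustrative note (not from the original file): for a rank-2 transfer op,
// the traversal order computed above is [1, 0], i.e. the 2x2 tiles are
// emitted innermost dimension first rather than in the default row-major
// order.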

struct TestScalarVectorTransferLoweringPatterns
    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestScalarVectorTransferLoweringPatterns)

  TestScalarVectorTransferLoweringPatterns() = default;
  TestScalarVectorTransferLoweringPatterns(
      const TestScalarVectorTransferLoweringPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-scalar-vector-transfer-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering of scalar vector transfers to memref loads/stores.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, memref::MemRefDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  Option<bool> allowMultipleUses{
      *this, "allow-multiple-uses",
      llvm::cl::desc("Fold transfer operations with multiple uses"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateScalarVectorTransferLoweringPatterns(
        patterns, /*benefit=*/1, allowMultipleUses.getValue());
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
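
// Illustrative example (not from the original file): a scalar transfer such as
//   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<1xf32>
// is expected to lower roughly to
//   %s = memref.load %m[%i] : memref<?xf32>
//   %v = vector.broadcast %s : f32 to vector<1xf32>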

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    transferOpflowOpt(rewriter, getOperation());
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorSinkPatterns
    : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)

  TestVectorSinkPatterns() = default;
  TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-sink-patterns"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that eliminate redundant broadcast "
           "and transpose operations.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateSinkVectorOpsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorChainedReductionFoldingPatterns
    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorChainedReductionFoldingPatterns)

  StringRef getArgument() const final {
    return "test-vector-chained-reduction-folding-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to fold chained vector reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateChainedVectorReductionFoldingPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownReductionPatterns
    : public PassWrapper<TestVectorBreakDownReductionPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorBreakDownReductionPatterns)

  StringRef getArgument() const final {
    return "test-vector-break-down-reduction-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to break down vector reductions into arith "
           "reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorReductionPatterns(patterns,
                                             /*maxNumElementsToExtract=*/2);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
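
// Illustrative example (not from the original file): with
// maxNumElementsToExtract = 2, a reduction such as
//   %r = vector.reduction <add>, %v : vector<2xf32> into f32
// is expected to be broken down roughly into
//   %e0 = vector.extract %v[0] : f32 from vector<2xf32>
//   %e1 = vector.extract %v[1] : f32 from vector<2xf32>
//   %r = arith.addf %e0, %e1 : f32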

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  TestFlattenVectorTransferPatterns() = default;
  TestFlattenVectorTransferPatterns(
      const TestFlattenVectorTransferPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }

  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation. "
          "For scalable vectors this is the base size, i.e. the size "
          "corresponding to vscale=1."),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
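
// Illustrative example (not from the original file): for a vector<32x32xf32>
// value, the helper above creates a module-level global roughly like
//   memref.global "private" @__shared_32x32xf32 : memref<32x32xf32, 3>
// and returns the value of a matching memref.get_global.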

static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  // First reduce on a single thread to get per lane reduction value.
  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
  // Parallel reduction using butterfly shuffles.
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}
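
// Illustrative trace (not from the original file): for size = 4 the loop runs
// with XOR masks 1 and 2. After the first step lane L holds
// combine(L, L ^ 1); after the second it additionally combines with lanes
// L ^ 2 and L ^ 3, so every lane ends up holding the full warp reduction.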

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane_0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<unsigned> maxTransferWriteElements{
      *this, "max-transfer-write-elements",
      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
      llvm::cl::init(1)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        WalkResult::interrupt();
      }
    });
    MLIRContext *ctx = &getContext();
    auto distributionFn = [](Value val) {
      // Create an identity dim map of the same rank as the vector.
      VectorType vecType = dyn_cast<VectorType>(val.getType());
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      OpBuilder builder(val.getContext());
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
    };
    auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                        Value srcIdx, int64_t warpSz) {
      assert((val.getType().isF32() || val.getType().isInteger(32)) &&
             "unsupported shuffle type");
      Type i32Type = builder.getIntegerType(32);
      Value srcIdxI32 =
          builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
      Value warpSzI32 = builder.create<arith::ConstantOp>(
          loc, builder.getIntegerAttr(i32Type, warpSz));
      Value result = builder
                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
                                                 gpu::ShuffleMode::IDX)
                         .getResult(0);
      return result;
    };
    if (distributeTransferWriteOps && propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn, /*benefit=*/1,
          /*readBenefit=*/0);
      vector::populateDistributeReduction(patterns, warpReduction, 1);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (distributeTransferWriteOps) {
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
                                                maxTransferWriteElements);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};
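
// Illustrative example (not from the original file): with
// distribute-transfer-write and the default of one element per lane, a
// vector<32xf32> vector.transfer_write nested in a 32-lane
// vector.warp_execute_on_lane_0 region is expected to be rewritten so that
// each lane writes a vector<1xf32> slice at an offset derived from the lane
// id.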

struct TestVectorExtractStridedSliceLowering
    : public PassWrapper<TestVectorExtractStridedSliceLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorExtractStridedSliceLowering)

  StringRef getArgument() const final {
    return "test-vector-extract-strided-slice-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that convert vector.extract_strided_slice "
           "into a chain of vector.extract and vector.insert ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorContiguousExtractStridedSliceToExtract
    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorContiguousExtractStridedSliceToExtract)

  StringRef getArgument() const final {
    return "test-vector-contiguous-extract-strided-slice-to-extract";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that rewrite simple cases of N-D "
           "extract_strided_slice, where the slice is contiguous, into extract "
           "and shape_cast";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownBitCast
    : public PassWrapper<TestVectorBreakDownBitCast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)

  StringRef getArgument() const final {
    return "test-vector-break-down-bitcast";
  }
  StringRef getDescription() const final {
    return "Test pattern that breaks down vector.bitcast ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
      return op.getSourceVectorType().getShape().back() > 4;
    });
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
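
// Illustrative example (not from the original file): the filter above only
// accepts bitcasts whose source has more than 4 innermost elements, so
//   %0 = vector.bitcast %a : vector<8xf32> to vector<16xf16>
// is expected to be broken down into smaller bitcasts (e.g. two
// vector<4xf32> to vector<8xf16> pieces) recombined with
// vector.insert_strided_slice.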

struct TestCreateVectorBroadcast
    : public PassWrapper<TestCreateVectorBroadcast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)

  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
  StringRef getDescription() const final {
    return "Test creating vector broadcasts for a given set of broadcast "
           "dimensions";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  void runOnOperation() override {
    getOperation()->walk([](Operation *op) {
      if (op->getName().getStringRef() != "test_create_broadcast")
        return;
      auto targetShape =
          cast<VectorType>(op->getResult(0).getType()).getShape();
      auto arrayAttr =
          cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
              .asArrayRef();
      llvm::SetVector<int64_t> broadcastedDims;
      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
      OpBuilder b(op);
      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
          b, op->getOperand(0), targetShape, broadcastedDims);
      op->getResult(0).replaceAllUsesWith(bcast);
      op->erase();
    });
  }
};

struct TestVectorGatherLowering
    : public PassWrapper<TestVectorGatherLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)

  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
  StringRef getDescription() const final {
    return "Test patterns that lower the gather op in the vector dialect to "
           "conditional loads";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorGatherLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFoldArithExtensionIntoVectorContractPatterns
    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFoldArithExtensionIntoVectorContractPatterns)

  StringRef getArgument() const final {
    return "test-fold-arith-extf-into-vector-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns that fold arithmetic extension ops into vector "
           "contract ops";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFoldArithExtensionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
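
// Illustrative example (not from the original file): folding the extensions in
//   %lhs = arith.extf %a : vector<4xf16> to vector<4xf32>
//   %rhs = arith.extf %b : vector<4xf16> to vector<4xf32>
//   %r = vector.contract {...} %lhs, %rhs, %acc : ...
// is expected to produce a mixed-precision vector.contract that consumes %a
// and %b directly.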

struct TestVectorEmulateMaskedLoadStore final
    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)

  StringRef getArgument() const override {
    return "test-vector-emulate-masked-load-store";
  }
  StringRef getDescription() const override {
    return "Test patterns that emulate the maskedload/maskedstore op by "
           "memref.load/store and scf.if";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
                scf::SCFDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorLinearize final
    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)

  TestVectorLinearize() = default;
  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}

  StringRef getArgument() const override { return "test-vector-linearize"; }
  StringRef getDescription() const override {
    return "Linearizes ND vectors for N >= 2 into 1D vectors";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};
  void runOnOperation() override {
    auto *context = &getContext();

    TypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    vector::populateVectorLinearizeTypeConversionsAndLegality(
        typeConverter, patterns, target, targetVectorBitwidth);
    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
        typeConverter, patterns, target, targetVectorBitwidth);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
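
// Illustrative example (not from the original file): after linearization, an
// op such as
//   %0 = arith.addf %a, %b : vector<2x2xf32>
// is expected to operate on vector<4xf32> instead, with vector.shape_cast ops
// inserted at the value boundaries by the type converter.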

struct TestEliminateVectorMasks
    : public PassWrapper<TestEliminateVectorMasks,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)

  TestEliminateVectorMasks() = default;
  TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
      : PassWrapper(pass) {}

  Option<unsigned> vscaleMin{
      *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
      llvm::cl::init(1)};
  Option<unsigned> vscaleMax{
      *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
      llvm::cl::init(16)};

  StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
  StringRef getDescription() const final {
    return "Test eliminating vector masks";
  }
  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    eliminateVectorMasks(rewriter, getOperation(),
                         VscaleRange{vscaleMin, vscaleMax});
  }
};
} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionPrepareForMMTLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestScalarVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorSinkPatterns>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorChainedReductionFoldingPatterns>();

  PassRegistration<TestVectorBreakDownReductionPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();

  PassRegistration<TestVectorExtractStridedSliceLowering>();

  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();

  PassRegistration<TestVectorBreakDownBitCast>();

  PassRegistration<TestCreateVectorBroadcast>();

  PassRegistration<TestVectorGatherLowering>();

  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

  PassRegistration<TestVectorEmulateMaskedLoadStore>();

  PassRegistration<TestVectorLinearize>();

  PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
} // namespace mlir