xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp (revision dd16cd731dfb4746a351380edc848199cf9631e8)
1 //===- GlobalIdRewriter.cpp - Implementation of GlobalId rewriting  -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the global_id op for archs
10 // where global_id.x = threadId.x + blockId.x * blockDim.x
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/GPU/Transforms/Passes.h"
16 #include "mlir/Dialect/Index/IR/IndexOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
25   using OpRewritePattern<gpu::GlobalIdOp>::OpRewritePattern;
26 
matchAndRewrite__anonab3eb14c0111::GpuGlobalIdRewriter27   LogicalResult matchAndRewrite(gpu::GlobalIdOp op,
28                                 PatternRewriter &rewriter) const override {
29     auto loc = op.getLoc();
30     auto dim = op.getDimension();
31     auto blockId = rewriter.create<gpu::BlockIdOp>(loc, dim);
32     auto blockDim = rewriter.create<gpu::BlockDimOp>(loc, dim);
33     // Compute blockId.x * blockDim.x
34     auto tmp = rewriter.create<index::MulOp>(op.getLoc(), blockId, blockDim);
35     auto threadId = rewriter.create<gpu::ThreadIdOp>(loc, dim);
36     // Compute threadId.x + blockId.x * blockDim.x
37     rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, tmp);
38     return success();
39   }
40 };
41 } // namespace
42 
populateGpuGlobalIdPatterns(RewritePatternSet & patterns)43 void mlir::populateGpuGlobalIdPatterns(RewritePatternSet &patterns) {
44   patterns.add<GpuGlobalIdRewriter>(patterns.getContext());
45 }
46