1 //===- GlobalIdRewriter.cpp - Implementation of GlobalId rewriting -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements in-dialect rewriting of the global_id op for archs 10 // where global_id.x = threadId.x + blockId.x * blockDim.x 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/GPU/Transforms/Passes.h" 16 #include "mlir/Dialect/Index/IR/IndexOps.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 21 using namespace mlir; 22 23 namespace { 24 struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> { 25 using OpRewritePattern<gpu::GlobalIdOp>::OpRewritePattern; 26 matchAndRewrite__anonab3eb14c0111::GpuGlobalIdRewriter27 LogicalResult matchAndRewrite(gpu::GlobalIdOp op, 28 PatternRewriter &rewriter) const override { 29 auto loc = op.getLoc(); 30 auto dim = op.getDimension(); 31 auto blockId = rewriter.create<gpu::BlockIdOp>(loc, dim); 32 auto blockDim = rewriter.create<gpu::BlockDimOp>(loc, dim); 33 // Compute blockId.x * blockDim.x 34 auto tmp = rewriter.create<index::MulOp>(op.getLoc(), blockId, blockDim); 35 auto threadId = rewriter.create<gpu::ThreadIdOp>(loc, dim); 36 // Compute threadId.x + blockId.x * blockDim.x 37 rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, tmp); 38 return success(); 39 } 40 }; 41 } // namespace 42 populateGpuGlobalIdPatterns(RewritePatternSet & patterns)43void mlir::populateGpuGlobalIdPatterns(RewritePatternSet &patterns) { 44 patterns.add<GpuGlobalIdRewriter>(patterns.getContext()); 45 } 46