xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp (revision 330a232ae76139c3970df5ccaf1b51640cbd4d66)
1 //===- ShuffleRewriter.cpp - Implementation of shuffle rewriting  ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the shuffle op for types i64 and
10 // f64, rewriting 64bit shuffles into two 32bit shuffles. This particular
// implementation using shifts and truncations can be obtained using clang by
// emitting IR for shuffle operations with `-O3`.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/Dialect/GPU/Transforms/Passes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
27   using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
28 
initialize__anona36ad37b0111::GpuShuffleRewriter29   void initialize() {
30     // Required as the pattern will replace the Op with 2 additional ShuffleOps.
31     setHasBoundedRewriteRecursion();
32   }
matchAndRewrite__anona36ad37b0111::GpuShuffleRewriter33   LogicalResult matchAndRewrite(gpu::ShuffleOp op,
34                                 PatternRewriter &rewriter) const override {
35     auto loc = op.getLoc();
36     auto value = op.getValue();
37     auto valueType = value.getType();
38     auto valueLoc = value.getLoc();
39     auto i32 = rewriter.getI32Type();
40     auto i64 = rewriter.getI64Type();
41 
42     // If the type of the value is either i32 or f32, the op is already valid.
43     if (valueType.getIntOrFloatBitWidth() == 32)
44       return failure();
45 
46     Value lo, hi;
47 
48     // Float types must be converted to i64 to extract the bits.
49     if (isa<FloatType>(valueType))
50       value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);
51 
52     // Get the low bits by trunc(value).
53     lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);
54 
55     // Get the high bits by trunc(value >> 32).
56     auto c32 = rewriter.create<arith::ConstantOp>(
57         valueLoc, rewriter.getIntegerAttr(i64, 32));
58     hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
59     hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);
60 
61     // Shuffle the values.
62     ValueRange loRes =
63         rewriter
64             .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
65                                     op.getWidth(), op.getMode())
66             .getResults();
67     ValueRange hiRes =
68         rewriter
69             .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(),
70                                     op.getWidth(), op.getMode())
71             .getResults();
72 
73     // Convert lo back to i64.
74     lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]);
75 
76     // Convert hi back to i64.
77     hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]);
78     hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);
79 
80     // Obtain the shuffled bits hi | lo.
81     value = rewriter.create<arith::OrIOp>(loc, hi, lo);
82 
83     // Convert the value back to float.
84     if (isa<FloatType>(valueType))
85       value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);
86 
87     // Obtain the shuffle validity by combining both validities.
88     auto validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]);
89 
90     // Replace the op.
91     rewriter.replaceOp(op, {value, validity});
92     return success();
93   }
94 };
95 } // namespace
96 
populateGpuShufflePatterns(RewritePatternSet & patterns)97 void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
98   patterns.add<GpuShuffleRewriter>(patterns.getContext());
99 }
100