//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
struct AmdgpuEmulateAtomicsPass
    : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
          AmdgpuEmulateAtomicsPass> {
  using AmdgpuEmulateAtomicsPassBase<
      AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
  void runOnOperation() override;
};

template <typename AtomicOp, typename ArithOp>
struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
  using OpConversionPattern<AtomicOp>::OpConversionPattern;
  using Adaptor = typename AtomicOp::Adaptor;

  LogicalResult
  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

namespace {
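// How patchOperandSegmentSizes below should treat the data operand's entry in
// operandSegmentSizes when retargeting an atomic's attributes.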
enum class DataArgAction : unsigned char {
  Duplicate,
  Drop,
};
} // namespace

// Fix up the fact that, when we're migrating from a general buffer atomic
// to a load or to a CAS, the number of operands, and thus the number of
// entries needed in operandSegmentSizes, needs to change. We use this method
// because we'd like to preserve unknown attributes on the atomic instead of
// discarding them.
static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
                                     SmallVectorImpl<NamedAttribute> &newAttrs,
                                     DataArgAction action) {
  newAttrs.reserve(attrs.size());
  for (NamedAttribute attr : attrs) {
    if (attr.getName().getValue() != "operandSegmentSizes") {
      newAttrs.push_back(attr);
      continue;
    }
    auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
    MLIRContext *context = segmentAttr.getContext();
    DenseI32ArrayAttr newSegments;
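    // The data operand comes first, so `Drop` removes the leading segment
    // entry (the emulating buffer load takes no data argument), while
    // `Duplicate` repeats it (the emulating cmpswap takes both a source and
    // a comparison value).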
    switch (action) {
    case DataArgAction::Drop:
      newSegments = DenseI32ArrayAttr::get(
          context, segmentAttr.asArrayRef().drop_front());
      break;
    case DataArgAction::Duplicate: {
      SmallVector<int32_t> newVals;
      ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
      newVals.push_back(oldVals[0]);
      newVals.append(oldVals.begin(), oldVals.end());
      newSegments = DenseI32ArrayAttr::get(context, newVals);
      break;
    }
    }
    newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
  }
}

// A helper function to flatten a vector value to a scalar containing its bits,
// returning the value unchanged if it is not a vector.
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
                              Value val) {
  auto vectorType = dyn_cast<VectorType>(val.getType());
  if (!vectorType)
    return val;

  int64_t bitwidth =
      vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  Type allBitsType = rewriter.getIntegerType(bitwidth);
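  // vector.bitcast cannot change rank, so bitcast to a one-element vector of
  // the combined-width integer first and then extract that single element.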
  auto allBitsVecType = VectorType::get({1}, allBitsType);
  Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
  Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
  return scalar;
}

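// The rewrite below emulates a read-modify-write buffer atomic with a CAS
// loop. Roughly, the emitted control flow looks like this (an illustrative
// sketch, not verbatim pass output; value names are made up):
//
//   %init = amdgpu.raw_buffer_load ...
//   cf.br ^loop(%init)
// ^loop(%prev):
//   %new = <arith op> %data, %prev
//   %seen = amdgpu.raw_buffer_atomic_cmpswap %new, %prev ...
//   %done = arith.cmpi eq, %seen, %prev   // compared on the values' raw bits
//   cf.cond_br %done, ^after, ^loop(%seen)
// ^after:
//   ...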
template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
    AtomicOp atomicOp, Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = atomicOp.getLoc();

  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
  ValueRange operands = adaptor.getOperands();
  Value data = operands.take_front()[0];
  ValueRange invariantArgs = operands.drop_front();
  Type dataType = data.getType();

  SmallVector<NamedAttribute> loadAttrs;
  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
  Value initialLoad =
      rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
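  // Split everything after the atomic off into `afterAtomic` and create the
  // loop block, which carries the most recently observed memory value as its
  // block argument.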
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *afterAtomic =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});

  rewriter.setInsertionPointToEnd(currentBlock);
  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);

  rewriter.setInsertionPointToEnd(loopBlock);
  Value prevLoad = loopBlock->getArgument(0);
  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
  dataType = operated.getType();

  SmallVector<NamedAttribute> cmpswapAttrs;
  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
      loc, dataType, cmpswapArgs, cmpswapAttrs);
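  // The cmpswap returns the value that was in memory before the operation:
  // if it equals `prevLoad`, the swap committed and `operated` is now in
  // memory; otherwise another thread won the race and we must retry with the
  // value that was actually observed.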

  // We care about exact bitwise equality here, so do some bitcasts.
  // These will fold away during lowering to the ROCDL dialect, where
  // an int->float bitcast is introduced to account for the fact that cmpswap
  // only takes integer arguments.

  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
    Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
    prevLoadForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
    atomicResForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
  }
  Value canLeave = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
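  // Leave the loop once the swap observed the value we expected; otherwise
  // loop again, feeding in the value memory actually held.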
  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
                                    loopBlock, atomicRes);
  rewriter.eraseOp(atomicOp);
  return success();
}

void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
    ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
  // gfx10, and gfx9 chips older than gfx908, have no buffer atomic fadd.
  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
    target.addIllegalOp<RawBufferAtomicFaddOp>();
  }
  // gfx11 has no fp16 atomics.
  if (chipset.majorVersion == 11) {
    target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
        [](RawBufferAtomicFaddOp op) -> bool {
          Type elemType = getElementTypeOrSelf(op.getValue().getType());
          return !isa<Float16Type, BFloat16Type>(elemType);
        });
  }
  // gfx9 has no, or only very limited, support for floating-point min and max.
  if (chipset.majorVersion == 9) {
    if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
      // gfx90a supports f64 max (and min, but we don't have a min wrapper right
      // now) but all other types need to be emulated.
      target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
          [](RawBufferAtomicFmaxOp op) -> bool {
            return op.getValue().getType().isF64();
          });
    } else {
      target.addIllegalOp<RawBufferAtomicFmaxOp>();
    }
    if (chipset == Chipset(9, 4, 1)) {
      // gfx941 requires non-CAS atomics to be implemented with CAS loops.
      // The workaround here mirrors HIP and OpenMP.
      target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
                          RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>();
    }
  }
  patterns.add<
      RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
      patterns.getContext());
}

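// Pass driver. This is typically exercised as, e.g.,
// `mlir-opt --amdgpu-emulate-atomics=chipset=gfx90a` (the flag and `chipset`
// option names are assumed here from the pass's tablegen definition).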
void AmdgpuEmulateAtomicsPass::runOnOperation() {
  Operation *op = getOperation();
  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(op->getLoc(), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  RewritePatternSet patterns(&ctx);
  target.markUnknownOpDynamicallyLegal(
      [](Operation *op) -> bool { return true; });
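  // Every op other than those marked illegal above stays legal, so the
  // partial conversion only rewrites atomics the target chipset cannot run.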

  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return signalPassFailure();
}