1 //===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h" 10 11 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" 12 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 15 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 namespace mlir::amdgpu { 22 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS 23 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc" 24 } // namespace mlir::amdgpu 25 26 using namespace mlir; 27 using namespace mlir::amdgpu; 28 29 namespace { 30 struct AmdgpuEmulateAtomicsPass 31 : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase< 32 AmdgpuEmulateAtomicsPass> { 33 using AmdgpuEmulateAtomicsPassBase< 34 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase; 35 void runOnOperation() override; 36 }; 37 38 template <typename AtomicOp, typename ArithOp> 39 struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> { 40 using OpConversionPattern<AtomicOp>::OpConversionPattern; 41 using Adaptor = typename AtomicOp::Adaptor; 42 43 LogicalResult 44 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override; 46 }; 47 } // namespace 48 49 namespace { 50 enum class DataArgAction : unsigned char { 51 Duplicate, 52 Drop, 53 }; 54 } // namespace 55 56 // Fix up the fact that, when we're migrating from a general bugffer atomic 57 // to a load or to a CAS, the number of openrands, and thus the number of 58 // entries needed in operandSegmentSizes, needs to change. We use this method 59 // because we'd like to preserve unknown attributes on the atomic instead of 60 // discarding them. 61 static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs, 62 SmallVectorImpl<NamedAttribute> &newAttrs, 63 DataArgAction action) { 64 newAttrs.reserve(attrs.size()); 65 for (NamedAttribute attr : attrs) { 66 if (attr.getName().getValue() != "operandSegmentSizes") { 67 newAttrs.push_back(attr); 68 continue; 69 } 70 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue()); 71 MLIRContext *context = segmentAttr.getContext(); 72 DenseI32ArrayAttr newSegments; 73 switch (action) { 74 case DataArgAction::Drop: 75 newSegments = DenseI32ArrayAttr::get( 76 context, segmentAttr.asArrayRef().drop_front()); 77 break; 78 case DataArgAction::Duplicate: { 79 SmallVector<int32_t> newVals; 80 ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef(); 81 newVals.push_back(oldVals[0]); 82 newVals.append(oldVals.begin(), oldVals.end()); 83 newSegments = DenseI32ArrayAttr::get(context, newVals); 84 break; 85 } 86 } 87 newAttrs.push_back(NamedAttribute(attr.getName(), newSegments)); 88 } 89 } 90 91 // A helper function to flatten a vector value to a scalar containing its bits, 92 // returning the value itself if othetwise. 93 static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, 94 Value val) { 95 auto vectorType = dyn_cast<VectorType>(val.getType()); 96 if (!vectorType) 97 return val; 98 99 int64_t bitwidth = 100 vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); 101 Type allBitsType = rewriter.getIntegerType(bitwidth); 102 auto allBitsVecType = VectorType::get({1}, allBitsType); 103 Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val); 104 Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0); 105 return scalar; 106 } 107 108 template <typename AtomicOp, typename ArithOp> 109 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite( 110 AtomicOp atomicOp, Adaptor adaptor, 111 ConversionPatternRewriter &rewriter) const { 112 Location loc = atomicOp.getLoc(); 113 114 ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs(); 115 ValueRange operands = adaptor.getOperands(); 116 Value data = operands.take_front()[0]; 117 ValueRange invariantArgs = operands.drop_front(); 118 Type dataType = data.getType(); 119 120 SmallVector<NamedAttribute> loadAttrs; 121 patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop); 122 Value initialLoad = 123 rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs); 124 Block *currentBlock = rewriter.getInsertionBlock(); 125 Block *afterAtomic = 126 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 127 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc}); 128 129 rewriter.setInsertionPointToEnd(currentBlock); 130 rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad); 131 132 rewriter.setInsertionPointToEnd(loopBlock); 133 Value prevLoad = loopBlock->getArgument(0); 134 Value operated = rewriter.create<ArithOp>(loc, data, prevLoad); 135 dataType = operated.getType(); 136 137 SmallVector<NamedAttribute> cmpswapAttrs; 138 patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate); 139 SmallVector<Value> cmpswapArgs = {operated, prevLoad}; 140 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end()); 141 Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>( 142 loc, dataType, cmpswapArgs, cmpswapAttrs); 143 144 // We care about exact bitwise equality here, so do some bitcasts. 145 // These will fold away during lowering to the ROCDL dialect, where 146 // an int->float bitcast is introduced to account for the fact that cmpswap 147 // only takes integer arguments. 148 149 Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad); 150 Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes); 151 if (auto floatDataTy = dyn_cast<FloatType>(dataType)) { 152 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth()); 153 prevLoadForCompare = 154 rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad); 155 atomicResForCompare = 156 rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes); 157 } 158 Value canLeave = rewriter.create<arith::CmpIOp>( 159 loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare); 160 rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{}, 161 loopBlock, atomicRes); 162 rewriter.eraseOp(atomicOp); 163 return success(); 164 } 165 166 void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns( 167 ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) { 168 // gfx10 has no atomic adds. 169 if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) { 170 target.addIllegalOp<RawBufferAtomicFaddOp>(); 171 } 172 // gfx11 has no fp16 atomics 173 if (chipset.majorVersion == 11) { 174 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>( 175 [](RawBufferAtomicFaddOp op) -> bool { 176 Type elemType = getElementTypeOrSelf(op.getValue().getType()); 177 return !isa<Float16Type, BFloat16Type>(elemType); 178 }); 179 } 180 // gfx9 has no to a very limited support for floating-point min and max. 181 if (chipset.majorVersion == 9) { 182 if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) { 183 // gfx90a supports f64 max (and min, but we don't have a min wrapper right 184 // now) but all other types need to be emulated. 185 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>( 186 [](RawBufferAtomicFmaxOp op) -> bool { 187 return op.getValue().getType().isF64(); 188 }); 189 } else { 190 target.addIllegalOp<RawBufferAtomicFmaxOp>(); 191 } 192 if (chipset == Chipset(9, 4, 1)) { 193 // gfx941 requires non-CAS atomics to be implemented with CAS loops. 194 // The workaround here mirrors HIP and OpenMP. 195 target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp, 196 RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>(); 197 } 198 } 199 patterns.add< 200 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>, 201 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>, 202 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>, 203 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>( 204 patterns.getContext()); 205 } 206 207 void AmdgpuEmulateAtomicsPass::runOnOperation() { 208 Operation *op = getOperation(); 209 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset); 210 if (failed(maybeChipset)) { 211 emitError(op->getLoc(), "Invalid chipset name: " + chipset); 212 return signalPassFailure(); 213 } 214 215 MLIRContext &ctx = getContext(); 216 ConversionTarget target(ctx); 217 RewritePatternSet patterns(&ctx); 218 target.markUnknownOpDynamicallyLegal( 219 [](Operation *op) -> bool { return true; }); 220 221 populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset); 222 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 223 return signalPassFailure(); 224 } 225