//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIR-V cooperative matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op maps directly to a SPIR-V op that supports the
/// cooperative matrix type. Returns false if it cannot.
///
/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
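///
/// As a sketch, an `addf` elementwise op would lower roughly as follows
/// (illustrative IR only; the exact cooperative matrix type is produced by
/// the type converter):
///
///   %r = gpu.subgroup_mma_elementwise addf %a, %b :
///          (!gpu.mma_matrix<16x16xf16, "COp">,
///           !gpu.mma_matrix<16x16xf16, "COp">)
///          -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes
///
///   %r = spirv.FAdd %a, %b :
///          !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>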
static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                ValueRange operands) {
  assert(isa<spirv::CooperativeMatrixType>(coopType));

  switch (op.getOpType()) {
  case gpu::MMAElementwiseOp::ADDF:
    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::ADDI:
    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBF:
    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBI:
    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVF:
    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVS:
    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVU:
    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATEF:
    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATES:
    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::EXTF:
    builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
    return true;
  default:
    break;
  }
  return false;
}

/// Returns true if all `operands` have the same cooperative matrix type.
static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
  assert(!operands.empty());
  if (!llvm::all_equal(
          llvm::map_range(operands, [](Value v) { return v.getType(); })))
    return false;

  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
}

namespace {
/// Converts GPU MMA ConstantMatrixOp to a constant SPIR-V KHR cooperative
/// matrix op.
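///
/// As a sketch (illustrative IR only; the cooperative matrix type comes from
/// the type converter):
///
///   %m = gpu.subgroup_mma_constant_matrix %cst :
///          !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes
///
///   %m = spirv.CompositeConstruct %cst :
///          (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>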
struct WmmaConstantOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Value cst = adaptor.getOperands().front();
    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
    return success();
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    return success(
        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the matrix-times-scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getOperands().size() != 2)
      return failure();

    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
      return failure();

    // Use the original operands to check whether one of the operands is a
    // splat scalar value.
    Value lhs = op.getOperands().front();
    Value rhs = op.getOperands().back();
    Value splat = nullptr;
    Value matrix = nullptr;
    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      splat = adaptor.getOperands().front();
      matrix = adaptor.getOperands().back();
    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      matrix = adaptor.getOperands().front();
      splat = adaptor.getOperands().back();
    }
    if (!splat || !matrix)
      return rewriter.notifyMatchFailure(op, "no splat operand");

    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
    if (!cc) {
      return rewriter.notifyMatchFailure(op,
                                         "splat is not a composite construct");
    }

    assert(cc.getConstituents().size() == 1);
    Value scalar = cc.getConstituents().front();

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
        op, coopType, ValueRange{matrix, scalar});
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

/// Converts the GPU MMA LoadOp to a KHRCooperativeMatrixLoad op in the SPIR-V
/// dialect.
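///
/// As a sketch (illustrative IR only; the pointer type depends on the source
/// memref and its storage class, and the layout on the `transpose` attribute):
///
///   %m = gpu.subgroup_mma_load_matrix %buf[%i, %j]
///          {leadDimension = 32 : index}
///          : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
///            !gpu.mma_matrix<16x16xf16, "AOp">
///
/// becomes
///
///   %stride = spirv.Constant 32 : i32
///   %m = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
///          : !spirv.ptr<f16, StorageBuffer>, i32
///            -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>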
struct WmmaLoadOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
    MemRefType memrefType = op.getSrcMemref().getType();
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), loc, rewriter);

    auto coopType =
        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
        op, coopType, bufferPtr, strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA StoreOp to a KHRCooperativeMatrixStore op in the
/// SPIR-V dialect.
struct WmmaStoreOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
                             adaptor.getIndices(), loc, rewriter);

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA ComputeOp to a KHRCooperativeMatrixMulAdd op in the
/// SPIR-V dialect.
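///
/// As a sketch (illustrative IR only; operand types come from the converted
/// operands):
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c :
///          !gpu.mma_matrix<16x16xf16, "AOp">,
///          !gpu.mma_matrix<16x16xf16, "BOp">
///          -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes
///
///   %d = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
///          !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
///          !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
///          -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>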
struct WmmaMmaOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
        adaptor.getOpC());
    return success();
  }
};

} // namespace
} // namespace khr
} // namespace mlir

void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
    const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
  using namespace mlir;
  MLIRContext *context = patterns.getContext();
  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
  // Give the following pattern higher benefit so that it prevails over the
  // default elementwise lowering.
  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
                                                          /*benefit=*/2);
}

void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
    mlir::SPIRVTypeConverter &typeConverter) {
  typeConverter.addConversion([](gpu::MMAMatrixType type) {
    ArrayRef<int64_t> retTypeShape = type.getShape();
    Type elementType = type.getElementType();
    // Map the `gpu.mma_matrix` operand string to a cooperative matrix use.
    // "COp" (and any other string) falls through to the accumulator use.
    auto use =
        llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
            .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
            .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
            .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

    return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
                                             retTypeShape[1],
                                             spirv::Scope::Subgroup, use);
  });
}
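
// Note: the type conversion registered above always produces Subgroup-scope
// cooperative matrix types, e.g. (illustrative):
//
//   !gpu.mma_matrix<16x16xf16, "AOp">
//       -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
//   !gpu.mma_matrix<16x16xf32, "COp">
//       -> !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>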