1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/AMX/Transforms.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/Pattern.h" 14 #include "mlir/Dialect/AMX/AMXDialect.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 using namespace mlir::amx; 21 22 namespace { 23 24 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first 25 /// dimension directly translates into the number of rows of the tiles. 26 /// The second dimensions needs to be scaled by the number of bytes. 27 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter, 28 const LLVMTypeConverter &typeConverter, 29 amx::TileType tType, Location loc) { 30 Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); 31 unsigned width = tType.getElementType().getIntOrFloatBitWidth(); 32 assert(llvm::isPowerOf2_64(width) && width >= 8); 33 unsigned bytes = width >> 3; 34 auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); 35 auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); 36 return std::make_pair( 37 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), 38 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)); 39 } 40 41 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer 42 /// shape may "envelop" the actual tile shape, and may be dynamically sized. 43 /// Returns failure if proper stride couldn't be found. 44 FailureOr<Value> getStride(ConversionPatternRewriter &rewriter, 45 const LLVMTypeConverter &typeConverter, 46 MemRefType mType, Value base, Location loc) { 47 if (mType.getRank() < 2) 48 return failure(); 49 int64_t preLast = mType.getRank() - 2; 50 Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); 51 unsigned width = mType.getElementType().getIntOrFloatBitWidth(); 52 assert(llvm::isPowerOf2_64(width) && width >= 8); 53 unsigned bytes = width >> 3; 54 int64_t offset; 55 SmallVector<int64_t, 4> strides; 56 if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1) 57 return failure(); 58 if (strides[preLast] == ShapedType::kDynamic) { 59 // Dynamic stride needs code to compute the stride at runtime. 60 MemRefDescriptor memrefDescriptor(base); 61 auto attr = rewriter.getI64IntegerAttr(bytes); 62 Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 63 return rewriter 64 .create<LLVM::MulOp>(loc, llvmInt64Type, scale, 65 memrefDescriptor.stride(rewriter, loc, preLast)) 66 .getResult(); 67 } 68 // Use direct constant for static stride. 69 auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); 70 return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr) 71 .getResult(); 72 } 73 74 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> { 75 using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern; 76 LogicalResult 77 matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, 78 ConversionPatternRewriter &rewriter) const override { 79 amx::TileType tType = op.getTileType(); 80 // Determine m x n tile sizes. 81 std::pair<Value, Value> tsz = 82 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); 83 // Replace operation with intrinsic. 84 Type resType = typeConverter->convertType(tType); 85 rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first, 86 tsz.second); 87 return success(); 88 } 89 }; 90 91 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> { 92 using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern; 93 94 LogicalResult 95 matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, 96 ConversionPatternRewriter &rewriter) const override { 97 MemRefType mType = op.getMemRefType(); 98 amx::TileType tType = op.getTileType(); 99 // Determine m x n tile sizes. 100 std::pair<Value, Value> tsz = 101 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); 102 // Determine stride. 103 auto stride = getStride(rewriter, *getTypeConverter(), mType, 104 adaptor.getBase(), op.getLoc()); 105 if (failed(stride)) 106 return failure(); 107 // Replace operation with intrinsic. 108 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), 109 adaptor.getIndices(), rewriter); 110 Type resType = typeConverter->convertType(tType); 111 rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>( 112 op, resType, tsz.first, tsz.second, ptr, stride.value()); 113 return success(); 114 } 115 }; 116 117 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> { 118 using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern; 119 120 LogicalResult 121 matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, 122 ConversionPatternRewriter &rewriter) const override { 123 MemRefType mType = op.getMemRefType(); 124 amx::TileType tType = op.getTileType(); 125 // Determine m x n tile sizes. 126 std::pair<Value, Value> tsz = 127 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); 128 // Determine stride. 129 auto stride = getStride(rewriter, *getTypeConverter(), mType, 130 adaptor.getBase(), op.getLoc()); 131 if (failed(stride)) 132 return failure(); 133 // Replace operation with intrinsic. 134 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), 135 adaptor.getIndices(), rewriter); 136 rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>( 137 op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal()); 138 return success(); 139 } 140 }; 141 142 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> { 143 using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern; 144 LogicalResult 145 matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, 146 ConversionPatternRewriter &rewriter) const override { 147 amx::TileType aType = op.getLhsTileType(); 148 amx::TileType bType = op.getRhsTileType(); 149 amx::TileType cType = op.getTileType(); 150 // Determine m x n x k tile sizes. 151 std::pair<Value, Value> tsza = 152 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); 153 std::pair<Value, Value> tszb = 154 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); 155 // Replace operation with intrinsic. 156 Type resType = typeConverter->convertType(cType); 157 if (aType.getElementType().isBF16()) 158 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>( 159 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 160 adaptor.getLhs(), adaptor.getRhs()); 161 else if (aType.getElementType().isF16()) 162 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>( 163 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 164 adaptor.getLhs(), adaptor.getRhs()); 165 else 166 llvm_unreachable("Unexpected element type for amx.mulf"); 167 return success(); 168 } 169 }; 170 171 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> { 172 using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern; 173 LogicalResult 174 matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, 175 ConversionPatternRewriter &rewriter) const override { 176 amx::TileType aType = op.getLhsTileType(); 177 amx::TileType bType = op.getRhsTileType(); 178 amx::TileType cType = op.getTileType(); 179 // Determine m x n x k tile sizes. 180 std::pair<Value, Value> tsza = 181 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); 182 std::pair<Value, Value> tszb = 183 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); 184 // Replace operation with intrinsic. 185 Type resType = typeConverter->convertType(cType); 186 bool zexta = op.getIsZextLhs(); 187 bool zextb = op.getIsZextRhs(); 188 if (zexta && zextb) 189 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>( 190 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 191 adaptor.getLhs(), adaptor.getRhs()); 192 else if (zexta && !zextb) 193 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>( 194 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 195 adaptor.getLhs(), adaptor.getRhs()); 196 else if (!zexta && zextb) 197 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>( 198 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 199 adaptor.getLhs(), adaptor.getRhs()); 200 else 201 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>( 202 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 203 adaptor.getLhs(), adaptor.getRhs()); 204 return success(); 205 } 206 }; 207 208 } // namespace 209 210 void mlir::populateAMXLegalizeForLLVMExportPatterns( 211 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 212 patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion, 213 TileMulFConversion, TileMulIConversion>(converter); 214 converter.addConversion([&](amx::TileType type) { 215 return LLVM::LLVMX86AMXType::get(&converter.getContext()); 216 }); 217 } 218 219 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { 220 target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64, 221 x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd, 222 x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>(); 223 target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp, 224 TileMulFOp>(); 225 } 226 227 namespace { 228 /// Implement the interface to convert AMX to LLVM. 229 struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 230 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 231 232 void populateConvertToLLVMConversionPatterns( 233 ConversionTarget &target, LLVMTypeConverter &typeConverter, 234 RewritePatternSet &patterns) const final { 235 populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); 236 } 237 }; 238 } // namespace 239 240 void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { 241 registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { 242 dialect->addInterfaces<AMXToLLVMDialectInterface>(); 243 }); 244 } 245