16ad7b97eSAart Bik //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// 26ad7b97eSAart Bik // 36ad7b97eSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 46ad7b97eSAart Bik // See https://llvm.org/LICENSE.txt for license information. 56ad7b97eSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 66ad7b97eSAart Bik // 76ad7b97eSAart Bik //===----------------------------------------------------------------------===// 86ad7b97eSAart Bik 96ad7b97eSAart Bik #include "mlir/Dialect/AMX/Transforms.h" 106ad7b97eSAart Bik 112f743ac5SIlya Enkovich #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 146ad7b97eSAart Bik #include "mlir/Dialect/AMX/AMXDialect.h" 156ad7b97eSAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 166ad7b97eSAart Bik #include "mlir/IR/BuiltinOps.h" 176ad7b97eSAart Bik #include "mlir/IR/PatternMatch.h" 186ad7b97eSAart Bik 196ad7b97eSAart Bik using namespace mlir; 206ad7b97eSAart Bik using namespace mlir::amx; 216ad7b97eSAart Bik 226ad7b97eSAart Bik namespace { 236ad7b97eSAart Bik 246ad7b97eSAart Bik /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first 256ad7b97eSAart Bik /// dimension directly translates into the number of rows of the tiles. 266ad7b97eSAart Bik /// The second dimensions needs to be scaled by the number of bytes. 276ad7b97eSAart Bik std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter, 28ce254598SMatthias Springer const LLVMTypeConverter &typeConverter, 292f743ac5SIlya Enkovich amx::TileType tType, Location loc) { 306ad7b97eSAart Bik Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); 312f743ac5SIlya Enkovich unsigned width = tType.getElementType().getIntOrFloatBitWidth(); 326ad7b97eSAart Bik assert(llvm::isPowerOf2_64(width) && width >= 8); 336ad7b97eSAart Bik unsigned bytes = width >> 3; 342f743ac5SIlya Enkovich auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); 352f743ac5SIlya Enkovich auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); 366ad7b97eSAart Bik return std::make_pair( 376ad7b97eSAart Bik rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), 386ad7b97eSAart Bik rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)); 396ad7b97eSAart Bik } 406ad7b97eSAart Bik 416ad7b97eSAart Bik /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer 426ad7b97eSAart Bik /// shape may "envelop" the actual tile shape, and may be dynamically sized. 43d2109640SIlya Enkovich /// Returns failure if proper stride couldn't be found. 44d2109640SIlya Enkovich FailureOr<Value> getStride(ConversionPatternRewriter &rewriter, 45d2109640SIlya Enkovich const LLVMTypeConverter &typeConverter, 46d2109640SIlya Enkovich MemRefType mType, Value base, Location loc) { 47d2109640SIlya Enkovich if (mType.getRank() < 2) 48d2109640SIlya Enkovich return failure(); 49d2109640SIlya Enkovich int64_t preLast = mType.getRank() - 2; 506ad7b97eSAart Bik Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); 516ad7b97eSAart Bik unsigned width = mType.getElementType().getIntOrFloatBitWidth(); 526ad7b97eSAart Bik assert(llvm::isPowerOf2_64(width) && width >= 8); 536ad7b97eSAart Bik unsigned bytes = width >> 3; 54d2109640SIlya Enkovich int64_t offset; 55d2109640SIlya Enkovich SmallVector<int64_t, 4> strides; 56*6aaa8f25SMatthias Springer if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1) 57d2109640SIlya Enkovich return failure(); 58d2109640SIlya Enkovich if (strides[preLast] == ShapedType::kDynamic) { 59d2109640SIlya Enkovich // Dynamic stride needs code to compute the stride at runtime. 606ad7b97eSAart Bik MemRefDescriptor memrefDescriptor(base); 616ad7b97eSAart Bik auto attr = rewriter.getI64IntegerAttr(bytes); 626ad7b97eSAart Bik Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); 63d2109640SIlya Enkovich return rewriter 64d2109640SIlya Enkovich .create<LLVM::MulOp>(loc, llvmInt64Type, scale, 65d2109640SIlya Enkovich memrefDescriptor.stride(rewriter, loc, preLast)) 66d2109640SIlya Enkovich .getResult(); 676ad7b97eSAart Bik } 68d2109640SIlya Enkovich // Use direct constant for static stride. 69d2109640SIlya Enkovich auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); 70d2109640SIlya Enkovich return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr) 71d2109640SIlya Enkovich .getResult(); 726ad7b97eSAart Bik } 736ad7b97eSAart Bik 746ad7b97eSAart Bik struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> { 756ad7b97eSAart Bik using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern; 766ad7b97eSAart Bik LogicalResult 77b54c724bSRiver Riddle matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, 786ad7b97eSAart Bik ConversionPatternRewriter &rewriter) const override { 792f743ac5SIlya Enkovich amx::TileType tType = op.getTileType(); 806ad7b97eSAart Bik // Determine m x n tile sizes. 816ad7b97eSAart Bik std::pair<Value, Value> tsz = 822f743ac5SIlya Enkovich getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); 836ad7b97eSAart Bik // Replace operation with intrinsic. 842f743ac5SIlya Enkovich Type resType = typeConverter->convertType(tType); 856ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first, 866ad7b97eSAart Bik tsz.second); 876ad7b97eSAart Bik return success(); 886ad7b97eSAart Bik } 896ad7b97eSAart Bik }; 906ad7b97eSAart Bik 916ad7b97eSAart Bik struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> { 926ad7b97eSAart Bik using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern; 936ad7b97eSAart Bik 946ad7b97eSAart Bik LogicalResult 95b54c724bSRiver Riddle matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, 966ad7b97eSAart Bik ConversionPatternRewriter &rewriter) const override { 976ad7b97eSAart Bik MemRefType mType = op.getMemRefType(); 982f743ac5SIlya Enkovich amx::TileType tType = op.getTileType(); 996ad7b97eSAart Bik // Determine m x n tile sizes. 1006ad7b97eSAart Bik std::pair<Value, Value> tsz = 1012f743ac5SIlya Enkovich getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); 1026ad7b97eSAart Bik // Determine stride. 103d2109640SIlya Enkovich auto stride = getStride(rewriter, *getTypeConverter(), mType, 1048df54a6aSJacques Pienaar adaptor.getBase(), op.getLoc()); 105d2109640SIlya Enkovich if (failed(stride)) 106d2109640SIlya Enkovich return failure(); 1076ad7b97eSAart Bik // Replace operation with intrinsic. 1088df54a6aSJacques Pienaar Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), 1098df54a6aSJacques Pienaar adaptor.getIndices(), rewriter); 1102f743ac5SIlya Enkovich Type resType = typeConverter->convertType(tType); 1116ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>( 112d2109640SIlya Enkovich op, resType, tsz.first, tsz.second, ptr, stride.value()); 1136ad7b97eSAart Bik return success(); 1146ad7b97eSAart Bik } 1156ad7b97eSAart Bik }; 1166ad7b97eSAart Bik 1176ad7b97eSAart Bik struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> { 1186ad7b97eSAart Bik using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern; 1196ad7b97eSAart Bik 1206ad7b97eSAart Bik LogicalResult 121b54c724bSRiver Riddle matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, 1226ad7b97eSAart Bik ConversionPatternRewriter &rewriter) const override { 1236ad7b97eSAart Bik MemRefType mType = op.getMemRefType(); 1242f743ac5SIlya Enkovich amx::TileType tType = op.getTileType(); 1256ad7b97eSAart Bik // Determine m x n tile sizes. 1266ad7b97eSAart Bik std::pair<Value, Value> tsz = 1272f743ac5SIlya Enkovich getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc()); 1286ad7b97eSAart Bik // Determine stride. 129d2109640SIlya Enkovich auto stride = getStride(rewriter, *getTypeConverter(), mType, 1308df54a6aSJacques Pienaar adaptor.getBase(), op.getLoc()); 131d2109640SIlya Enkovich if (failed(stride)) 132d2109640SIlya Enkovich return failure(); 1336ad7b97eSAart Bik // Replace operation with intrinsic. 1348df54a6aSJacques Pienaar Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), 1358df54a6aSJacques Pienaar adaptor.getIndices(), rewriter); 1366ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>( 137d2109640SIlya Enkovich op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal()); 1386ad7b97eSAart Bik return success(); 1396ad7b97eSAart Bik } 1406ad7b97eSAart Bik }; 1416ad7b97eSAart Bik 1426ad7b97eSAart Bik struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> { 1436ad7b97eSAart Bik using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern; 1446ad7b97eSAart Bik LogicalResult 145b54c724bSRiver Riddle matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, 1466ad7b97eSAart Bik ConversionPatternRewriter &rewriter) const override { 1472f743ac5SIlya Enkovich amx::TileType aType = op.getLhsTileType(); 1482f743ac5SIlya Enkovich amx::TileType bType = op.getRhsTileType(); 1492f743ac5SIlya Enkovich amx::TileType cType = op.getTileType(); 1506ad7b97eSAart Bik // Determine m x n x k tile sizes. 1516ad7b97eSAart Bik std::pair<Value, Value> tsza = 1526ad7b97eSAart Bik getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); 1536ad7b97eSAart Bik std::pair<Value, Value> tszb = 1546ad7b97eSAart Bik getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); 1556ad7b97eSAart Bik // Replace operation with intrinsic. 1566ad7b97eSAart Bik Type resType = typeConverter->convertType(cType); 1572f743ac5SIlya Enkovich if (aType.getElementType().isBF16()) 1586ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>( 1598df54a6aSJacques Pienaar op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 1608df54a6aSJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 1612f743ac5SIlya Enkovich else if (aType.getElementType().isF16()) 1622f743ac5SIlya Enkovich rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>( 1632f743ac5SIlya Enkovich op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 1642f743ac5SIlya Enkovich adaptor.getLhs(), adaptor.getRhs()); 1652f743ac5SIlya Enkovich else 1662f743ac5SIlya Enkovich llvm_unreachable("Unexpected element type for amx.mulf"); 1676ad7b97eSAart Bik return success(); 1686ad7b97eSAart Bik } 1696ad7b97eSAart Bik }; 1706ad7b97eSAart Bik 1716ad7b97eSAart Bik struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> { 1726ad7b97eSAart Bik using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern; 1736ad7b97eSAart Bik LogicalResult 174b54c724bSRiver Riddle matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, 1756ad7b97eSAart Bik ConversionPatternRewriter &rewriter) const override { 1762f743ac5SIlya Enkovich amx::TileType aType = op.getLhsTileType(); 1772f743ac5SIlya Enkovich amx::TileType bType = op.getRhsTileType(); 1782f743ac5SIlya Enkovich amx::TileType cType = op.getTileType(); 1796ad7b97eSAart Bik // Determine m x n x k tile sizes. 1806ad7b97eSAart Bik std::pair<Value, Value> tsza = 1816ad7b97eSAart Bik getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); 1826ad7b97eSAart Bik std::pair<Value, Value> tszb = 1836ad7b97eSAart Bik getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); 1846ad7b97eSAart Bik // Replace operation with intrinsic. 1856ad7b97eSAart Bik Type resType = typeConverter->convertType(cType); 1868df54a6aSJacques Pienaar bool zexta = op.getIsZextLhs(); 1878df54a6aSJacques Pienaar bool zextb = op.getIsZextRhs(); 1886ad7b97eSAart Bik if (zexta && zextb) 1896ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>( 1908df54a6aSJacques Pienaar op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 1918df54a6aSJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 1926ad7b97eSAart Bik else if (zexta && !zextb) 1936ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>( 1948df54a6aSJacques Pienaar op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 1958df54a6aSJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 1966ad7b97eSAart Bik else if (!zexta && zextb) 1976ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>( 1988df54a6aSJacques Pienaar op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 1998df54a6aSJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 2006ad7b97eSAart Bik else 2016ad7b97eSAart Bik rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>( 2028df54a6aSJacques Pienaar op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), 2038df54a6aSJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 2046ad7b97eSAart Bik return success(); 2056ad7b97eSAart Bik } 2066ad7b97eSAart Bik }; 2076ad7b97eSAart Bik 2086ad7b97eSAart Bik } // namespace 2096ad7b97eSAart Bik 2106ad7b97eSAart Bik void mlir::populateAMXLegalizeForLLVMExportPatterns( 2112f743ac5SIlya Enkovich LLVMTypeConverter &converter, RewritePatternSet &patterns) { 212dc4e913bSChris Lattner patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion, 2136ad7b97eSAart Bik TileMulFConversion, TileMulIConversion>(converter); 2142f743ac5SIlya Enkovich converter.addConversion([&](amx::TileType type) { 2152f743ac5SIlya Enkovich return LLVM::LLVMX86AMXType::get(&converter.getContext()); 2162f743ac5SIlya Enkovich }); 2176ad7b97eSAart Bik } 2186ad7b97eSAart Bik 2196ad7b97eSAart Bik void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { 2206ad7b97eSAart Bik target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64, 2212f743ac5SIlya Enkovich x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd, 2222f743ac5SIlya Enkovich x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>(); 2236ad7b97eSAart Bik target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp, 2246ad7b97eSAart Bik TileMulFOp>(); 2256ad7b97eSAart Bik } 2262f743ac5SIlya Enkovich 2272f743ac5SIlya Enkovich namespace { 2282f743ac5SIlya Enkovich /// Implement the interface to convert AMX to LLVM. 2292f743ac5SIlya Enkovich struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 2302f743ac5SIlya Enkovich using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 2312f743ac5SIlya Enkovich 2322f743ac5SIlya Enkovich void populateConvertToLLVMConversionPatterns( 2332f743ac5SIlya Enkovich ConversionTarget &target, LLVMTypeConverter &typeConverter, 2342f743ac5SIlya Enkovich RewritePatternSet &patterns) const final { 2352f743ac5SIlya Enkovich populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); 2362f743ac5SIlya Enkovich } 2372f743ac5SIlya Enkovich }; 2382f743ac5SIlya Enkovich } // namespace 2392f743ac5SIlya Enkovich 2402f743ac5SIlya Enkovich void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { 2412f743ac5SIlya Enkovich registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { 2422f743ac5SIlya Enkovich dialect->addInterfaces<AMXToLLVMDialectInterface>(); 2432f743ac5SIlya Enkovich }); 2442f743ac5SIlya Enkovich } 245