xref: /llvm-project/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/AMX/Transforms.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/AMX/AMXDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::amx;
21 
22 namespace {
23 
24 /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
25 /// dimension directly translates into the number of rows of the tiles.
26 /// The second dimensions needs to be scaled by the number of bytes.
27 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
28                                      const LLVMTypeConverter &typeConverter,
29                                      amx::TileType tType, Location loc) {
30   Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
31   unsigned width = tType.getElementType().getIntOrFloatBitWidth();
32   assert(llvm::isPowerOf2_64(width) && width >= 8);
33   unsigned bytes = width >> 3;
34   auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
35   auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
36   return std::make_pair(
37       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
38       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
39 }
40 
41 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
42 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
43 /// Returns failure if proper stride couldn't be found.
44 FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
45                            const LLVMTypeConverter &typeConverter,
46                            MemRefType mType, Value base, Location loc) {
47   if (mType.getRank() < 2)
48     return failure();
49   int64_t preLast = mType.getRank() - 2;
50   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
51   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
52   assert(llvm::isPowerOf2_64(width) && width >= 8);
53   unsigned bytes = width >> 3;
54   int64_t offset;
55   SmallVector<int64_t, 4> strides;
56   if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
57     return failure();
58   if (strides[preLast] == ShapedType::kDynamic) {
59     // Dynamic stride needs code to compute the stride at runtime.
60     MemRefDescriptor memrefDescriptor(base);
61     auto attr = rewriter.getI64IntegerAttr(bytes);
62     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
63     return rewriter
64         .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
65                              memrefDescriptor.stride(rewriter, loc, preLast))
66         .getResult();
67   }
68   // Use direct constant for static stride.
69   auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
70   return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
71       .getResult();
72 }
73 
74 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
75   using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
76   LogicalResult
77   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
78                   ConversionPatternRewriter &rewriter) const override {
79     amx::TileType tType = op.getTileType();
80     // Determine m x n tile sizes.
81     std::pair<Value, Value> tsz =
82         getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
83     // Replace operation with intrinsic.
84     Type resType = typeConverter->convertType(tType);
85     rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
86                                                        tsz.second);
87     return success();
88   }
89 };
90 
91 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
92   using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
93 
94   LogicalResult
95   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
96                   ConversionPatternRewriter &rewriter) const override {
97     MemRefType mType = op.getMemRefType();
98     amx::TileType tType = op.getTileType();
99     // Determine m x n tile sizes.
100     std::pair<Value, Value> tsz =
101         getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
102     // Determine stride.
103     auto stride = getStride(rewriter, *getTypeConverter(), mType,
104                             adaptor.getBase(), op.getLoc());
105     if (failed(stride))
106       return failure();
107     // Replace operation with intrinsic.
108     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
109                                      adaptor.getIndices(), rewriter);
110     Type resType = typeConverter->convertType(tType);
111     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
112         op, resType, tsz.first, tsz.second, ptr, stride.value());
113     return success();
114   }
115 };
116 
117 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
118   using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
119 
120   LogicalResult
121   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
122                   ConversionPatternRewriter &rewriter) const override {
123     MemRefType mType = op.getMemRefType();
124     amx::TileType tType = op.getTileType();
125     // Determine m x n tile sizes.
126     std::pair<Value, Value> tsz =
127         getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
128     // Determine stride.
129     auto stride = getStride(rewriter, *getTypeConverter(), mType,
130                             adaptor.getBase(), op.getLoc());
131     if (failed(stride))
132       return failure();
133     // Replace operation with intrinsic.
134     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
135                                      adaptor.getIndices(), rewriter);
136     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
137         op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
138     return success();
139   }
140 };
141 
142 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
143   using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
144   LogicalResult
145   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
146                   ConversionPatternRewriter &rewriter) const override {
147     amx::TileType aType = op.getLhsTileType();
148     amx::TileType bType = op.getRhsTileType();
149     amx::TileType cType = op.getTileType();
150     // Determine m x n x k tile sizes.
151     std::pair<Value, Value> tsza =
152         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
153     std::pair<Value, Value> tszb =
154         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
155     // Replace operation with intrinsic.
156     Type resType = typeConverter->convertType(cType);
157     if (aType.getElementType().isBF16())
158       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
159           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
160           adaptor.getLhs(), adaptor.getRhs());
161     else if (aType.getElementType().isF16())
162       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
163           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
164           adaptor.getLhs(), adaptor.getRhs());
165     else
166       llvm_unreachable("Unexpected element type for amx.mulf");
167     return success();
168   }
169 };
170 
171 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
172   using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
173   LogicalResult
174   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
175                   ConversionPatternRewriter &rewriter) const override {
176     amx::TileType aType = op.getLhsTileType();
177     amx::TileType bType = op.getRhsTileType();
178     amx::TileType cType = op.getTileType();
179     // Determine m x n x k tile sizes.
180     std::pair<Value, Value> tsza =
181         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
182     std::pair<Value, Value> tszb =
183         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
184     // Replace operation with intrinsic.
185     Type resType = typeConverter->convertType(cType);
186     bool zexta = op.getIsZextLhs();
187     bool zextb = op.getIsZextRhs();
188     if (zexta && zextb)
189       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
190           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
191           adaptor.getLhs(), adaptor.getRhs());
192     else if (zexta && !zextb)
193       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
194           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
195           adaptor.getLhs(), adaptor.getRhs());
196     else if (!zexta && zextb)
197       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
198           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
199           adaptor.getLhs(), adaptor.getRhs());
200     else
201       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
202           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
203           adaptor.getLhs(), adaptor.getRhs());
204     return success();
205   }
206 };
207 
208 } // namespace
209 
210 void mlir::populateAMXLegalizeForLLVMExportPatterns(
211     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
212   patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
213                TileMulFConversion, TileMulIConversion>(converter);
214   converter.addConversion([&](amx::TileType type) {
215     return LLVM::LLVMX86AMXType::get(&converter.getContext());
216   });
217 }
218 
219 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
220   target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
221                     x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
222                     x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
223   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
224                       TileMulFOp>();
225 }
226 
227 namespace {
228 /// Implement the interface to convert AMX to LLVM.
229 struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
230   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
231 
232   void populateConvertToLLVMConversionPatterns(
233       ConversionTarget &target, LLVMTypeConverter &typeConverter,
234       RewritePatternSet &patterns) const final {
235     populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
236   }
237 };
238 } // namespace
239 
240 void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
241   registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
242     dialect->addInterfaces<AMXToLLVMDialectInterface>();
243   });
244 }
245