//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

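// Check that the vector type can be lowered to XeGPU nd-ops, i.e., that it is
// a 1D or 2D vector.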
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector load and store ops
  // already guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

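// Check transfer op preconditions on top of the common load/store ones:
// no mask, a memref source that is contiguous in the innermost dimension, and
// a projected permutation map that accesses only the innermost dimensions.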
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

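// Create an XeGPU nd tensor descriptor for the given memref source and
// offsets. For statically shaped sources, the shape and strides come from the
// memref type; otherwise, the dynamic shape and stride values are materialized
// explicitly from memref.dim ops.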
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be provided explicitly.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

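// Lower vector.transfer_read to xegpu.create_nd_tdesc + xegpu.load_nd.
// Transposed reads are expressed through the load's transpose attribute.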
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape{vecTy.getShape()};
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

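// Lower vector.transfer_write to xegpu.create_nd_tdesc + xegpu.store_nd.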
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

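// Lower vector.load to xegpu.create_nd_tdesc + xegpu.load_nd.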
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        /*boundary_check=*/true, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

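// Lower vector.store to xegpu.create_nd_tdesc + xegpu.store_nd.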
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    auto descType =
        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
                                   /*array_length=*/1, /*boundary_check=*/true,
                                   xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

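// Pass driver that greedily applies the vector-to-XeGPU lowering patterns.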
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}