xref: /llvm-project/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
102d34d80SAdam Siemieniuk //===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
202d34d80SAdam Siemieniuk //
302d34d80SAdam Siemieniuk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
402d34d80SAdam Siemieniuk // See https://llvm.org/LICENSE.txt for license information.
502d34d80SAdam Siemieniuk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
602d34d80SAdam Siemieniuk //
702d34d80SAdam Siemieniuk //===----------------------------------------------------------------------===//
802d34d80SAdam Siemieniuk //
902d34d80SAdam Siemieniuk // This file implements lowering of vector operations to XeGPU dialect ops.
1002d34d80SAdam Siemieniuk //
1102d34d80SAdam Siemieniuk //===----------------------------------------------------------------------===//
1202d34d80SAdam Siemieniuk 
1302d34d80SAdam Siemieniuk #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
1402d34d80SAdam Siemieniuk 
1502d34d80SAdam Siemieniuk #include "mlir/Dialect/Arith/IR/Arith.h"
1602d34d80SAdam Siemieniuk #include "mlir/Dialect/MemRef/IR/MemRef.h"
1702d34d80SAdam Siemieniuk #include "mlir/Dialect/Vector/IR/VectorOps.h"
1802d34d80SAdam Siemieniuk #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1902d34d80SAdam Siemieniuk #include "mlir/Pass/Pass.h"
2002d34d80SAdam Siemieniuk #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2102d34d80SAdam Siemieniuk #include "mlir/Transforms/Passes.h"
2202d34d80SAdam Siemieniuk #include "llvm/ADT/TypeSwitch.h"
2302d34d80SAdam Siemieniuk 
2402d34d80SAdam Siemieniuk #include <algorithm>
2502d34d80SAdam Siemieniuk #include <optional>
2602d34d80SAdam Siemieniuk 
2702d34d80SAdam Siemieniuk namespace mlir {
2802d34d80SAdam Siemieniuk #define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
2902d34d80SAdam Siemieniuk #include "mlir/Conversion/Passes.h.inc"
3002d34d80SAdam Siemieniuk } // namespace mlir
3102d34d80SAdam Siemieniuk 
3202d34d80SAdam Siemieniuk using namespace mlir;
3302d34d80SAdam Siemieniuk 
3402d34d80SAdam Siemieniuk namespace {
3502d34d80SAdam Siemieniuk 
366c25604dSAdam Siemieniuk // Return true if value represents a zero constant.
3702d34d80SAdam Siemieniuk static bool isZeroConstant(Value val) {
3802d34d80SAdam Siemieniuk   auto constant = val.getDefiningOp<arith::ConstantOp>();
3902d34d80SAdam Siemieniuk   if (!constant)
4002d34d80SAdam Siemieniuk     return false;
4102d34d80SAdam Siemieniuk 
4202d34d80SAdam Siemieniuk   return TypeSwitch<Attribute, bool>(constant.getValue())
4302d34d80SAdam Siemieniuk       .Case<FloatAttr>(
4402d34d80SAdam Siemieniuk           [](auto floatAttr) { return floatAttr.getValue().isZero(); })
4502d34d80SAdam Siemieniuk       .Case<IntegerAttr>(
4602d34d80SAdam Siemieniuk           [](auto intAttr) { return intAttr.getValue().isZero(); })
4702d34d80SAdam Siemieniuk       .Default([](auto) { return false; });
4802d34d80SAdam Siemieniuk }
4902d34d80SAdam Siemieniuk 
506c25604dSAdam Siemieniuk static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
516c25604dSAdam Siemieniuk                                             Operation *op, VectorType vecTy) {
526c25604dSAdam Siemieniuk   // Validate only vector as the basic vector store and load ops guarantee
536c25604dSAdam Siemieniuk   // XeGPU-compatible memref source.
546c25604dSAdam Siemieniuk   unsigned vecRank = vecTy.getRank();
556c25604dSAdam Siemieniuk   if (!(vecRank == 1 || vecRank == 2))
566c25604dSAdam Siemieniuk     return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
576c25604dSAdam Siemieniuk 
586c25604dSAdam Siemieniuk   return success();
596c25604dSAdam Siemieniuk }
606c25604dSAdam Siemieniuk 
6102d34d80SAdam Siemieniuk static LogicalResult transferPreconditions(PatternRewriter &rewriter,
6202d34d80SAdam Siemieniuk                                            VectorTransferOpInterface xferOp) {
6302d34d80SAdam Siemieniuk   if (xferOp.getMask())
6402d34d80SAdam Siemieniuk     return rewriter.notifyMatchFailure(xferOp,
6502d34d80SAdam Siemieniuk                                        "Masked transfer is not supported");
6602d34d80SAdam Siemieniuk 
6702d34d80SAdam Siemieniuk   auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
6802d34d80SAdam Siemieniuk   if (!srcTy)
6902d34d80SAdam Siemieniuk     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
7002d34d80SAdam Siemieniuk 
716c25604dSAdam Siemieniuk   // Perform common data transfer checks.
726c25604dSAdam Siemieniuk   VectorType vecTy = xferOp.getVectorType();
736c25604dSAdam Siemieniuk   if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
746c25604dSAdam Siemieniuk     return failure();
756c25604dSAdam Siemieniuk 
766c25604dSAdam Siemieniuk   // Validate further transfer op semantics.
7702d34d80SAdam Siemieniuk   SmallVector<int64_t> strides;
7802d34d80SAdam Siemieniuk   int64_t offset;
796aaa8f25SMatthias Springer   if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
8002d34d80SAdam Siemieniuk     return rewriter.notifyMatchFailure(
8102d34d80SAdam Siemieniuk         xferOp, "Buffer must be contiguous in the innermost dimension");
8202d34d80SAdam Siemieniuk 
836c25604dSAdam Siemieniuk   unsigned vecRank = vecTy.getRank();
844c597d42SAdam Siemieniuk   if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
854c597d42SAdam Siemieniuk     return rewriter.notifyMatchFailure(
864c597d42SAdam Siemieniuk         xferOp, "Boundary check is available only for block instructions.");
874c597d42SAdam Siemieniuk 
8802d34d80SAdam Siemieniuk   AffineMap map = xferOp.getPermutationMap();
8902d34d80SAdam Siemieniuk   if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
9002d34d80SAdam Siemieniuk     return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
9102d34d80SAdam Siemieniuk   unsigned numInputDims = map.getNumInputs();
9202d34d80SAdam Siemieniuk   for (AffineExpr expr : map.getResults().take_back(vecRank)) {
9302d34d80SAdam Siemieniuk     auto dim = dyn_cast<AffineDimExpr>(expr);
9402d34d80SAdam Siemieniuk     if (dim.getPosition() < (numInputDims - vecRank))
9502d34d80SAdam Siemieniuk       return rewriter.notifyMatchFailure(
9602d34d80SAdam Siemieniuk           xferOp, "Only the innermost dimensions can be accessed");
9702d34d80SAdam Siemieniuk   }
9802d34d80SAdam Siemieniuk 
9902d34d80SAdam Siemieniuk   return success();
10002d34d80SAdam Siemieniuk }
10102d34d80SAdam Siemieniuk 
// Builds an xegpu::CreateNdDescOp describing a tile of `src` that starts at
// `offsets` and has the shape/element type given by `descType`.
//
// For statically shaped sources, the shape and strides are carried by the
// memref type itself. Otherwise, they must be provided to the descriptor
// explicitly as mixed static/dynamic operand lists.
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    // Static source: only the offsets need to be passed as operands.
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    // Materialize every dimension size up front; dynamic strides are rebuilt
    // from these values below.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    // Split offsets into a full static array (with kDynamic placeholders)
    // plus the ordered list of dynamic offset values.
    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    // Runtime size values are only needed for the dynamic dimensions.
    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    // NOTE(review): strides are reconstructed as running products of the
    // trailing dimension sizes, which presumes a row-major contiguous
    // layout for the dynamic strides — confirm for non-canonical memrefs.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    // Restore outermost-to-innermost order expected by the descriptor.
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
15702d34d80SAdam Siemieniuk 
15802d34d80SAdam Siemieniuk struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
15902d34d80SAdam Siemieniuk   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
16002d34d80SAdam Siemieniuk 
16102d34d80SAdam Siemieniuk   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
16202d34d80SAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
16302d34d80SAdam Siemieniuk     Location loc = readOp.getLoc();
16402d34d80SAdam Siemieniuk 
16502d34d80SAdam Siemieniuk     if (failed(transferPreconditions(rewriter, readOp)))
16602d34d80SAdam Siemieniuk       return failure();
16702d34d80SAdam Siemieniuk 
16802d34d80SAdam Siemieniuk     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
16902d34d80SAdam Siemieniuk     if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
17002d34d80SAdam Siemieniuk       return rewriter.notifyMatchFailure(
17102d34d80SAdam Siemieniuk           readOp, "Unsupported non-zero padded out-of-bounds read");
17202d34d80SAdam Siemieniuk 
17302d34d80SAdam Siemieniuk     AffineMap readMap = readOp.getPermutationMap();
17402d34d80SAdam Siemieniuk     bool isTransposeLoad = !readMap.isMinorIdentity();
17502d34d80SAdam Siemieniuk 
17602d34d80SAdam Siemieniuk     VectorType vecTy = readOp.getVectorType();
17702d34d80SAdam Siemieniuk     Type elementType = vecTy.getElementType();
17802d34d80SAdam Siemieniuk     unsigned minTransposeBitWidth = 32;
17902d34d80SAdam Siemieniuk     if (isTransposeLoad &&
18002d34d80SAdam Siemieniuk         elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
18102d34d80SAdam Siemieniuk       return rewriter.notifyMatchFailure(
182*aa295216SJay Foad           readOp, "Unsupported data type for transposition");
18302d34d80SAdam Siemieniuk 
18402d34d80SAdam Siemieniuk     // If load is transposed, get the base shape for the tensor descriptor.
1859cbc1f29SHan-Chung Wang     SmallVector<int64_t> descShape(vecTy.getShape());
18602d34d80SAdam Siemieniuk     if (isTransposeLoad)
18702d34d80SAdam Siemieniuk       std::reverse(descShape.begin(), descShape.end());
18802d34d80SAdam Siemieniuk     auto descType = xegpu::TensorDescType::get(
1898b5e8414SChao Chen         descShape, elementType, /*array_length=*/1,
1908b5e8414SChao Chen         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
19102d34d80SAdam Siemieniuk 
19202d34d80SAdam Siemieniuk     xegpu::CreateNdDescOp ndDesc =
19302d34d80SAdam Siemieniuk         createNdDescriptor(rewriter, loc, descType,
19402d34d80SAdam Siemieniuk                            dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
19502d34d80SAdam Siemieniuk                            readOp.getIndices());
19602d34d80SAdam Siemieniuk 
19702d34d80SAdam Siemieniuk     DenseI64ArrayAttr transposeAttr =
19802d34d80SAdam Siemieniuk         !isTransposeLoad ? nullptr
19902d34d80SAdam Siemieniuk                          : DenseI64ArrayAttr::get(rewriter.getContext(),
20002d34d80SAdam Siemieniuk                                                   ArrayRef<int64_t>{1, 0});
20102d34d80SAdam Siemieniuk     // By default, no specific caching policy is assigned.
20202d34d80SAdam Siemieniuk     xegpu::CachePolicyAttr hint = nullptr;
20302d34d80SAdam Siemieniuk     auto loadOp = rewriter.create<xegpu::LoadNdOp>(
20402d34d80SAdam Siemieniuk         loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
20502d34d80SAdam Siemieniuk         /*l1_hint=*/hint,
20602d34d80SAdam Siemieniuk         /*l2_hint=*/hint, /*l3_hint=*/hint);
20702d34d80SAdam Siemieniuk     rewriter.replaceOp(readOp, loadOp);
20802d34d80SAdam Siemieniuk 
20902d34d80SAdam Siemieniuk     return success();
21002d34d80SAdam Siemieniuk   }
21102d34d80SAdam Siemieniuk };
21202d34d80SAdam Siemieniuk 
21302d34d80SAdam Siemieniuk struct TransferWriteLowering
21402d34d80SAdam Siemieniuk     : public OpRewritePattern<vector::TransferWriteOp> {
21502d34d80SAdam Siemieniuk   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
21602d34d80SAdam Siemieniuk 
21702d34d80SAdam Siemieniuk   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
21802d34d80SAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
21902d34d80SAdam Siemieniuk     Location loc = writeOp.getLoc();
22002d34d80SAdam Siemieniuk 
22102d34d80SAdam Siemieniuk     if (failed(transferPreconditions(rewriter, writeOp)))
22202d34d80SAdam Siemieniuk       return failure();
22302d34d80SAdam Siemieniuk 
22402d34d80SAdam Siemieniuk     AffineMap map = writeOp.getPermutationMap();
22502d34d80SAdam Siemieniuk     if (!map.isMinorIdentity())
22602d34d80SAdam Siemieniuk       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
22702d34d80SAdam Siemieniuk 
22802d34d80SAdam Siemieniuk     VectorType vecTy = writeOp.getVectorType();
229ec450b19SAdam Siemieniuk     auto descType = xegpu::TensorDescType::get(
230ec450b19SAdam Siemieniuk         vecTy.getShape(), vecTy.getElementType(),
231ec450b19SAdam Siemieniuk         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
2328b5e8414SChao Chen         xegpu::MemorySpace::Global);
23302d34d80SAdam Siemieniuk     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
23402d34d80SAdam Siemieniuk         rewriter, loc, descType,
23502d34d80SAdam Siemieniuk         dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
23602d34d80SAdam Siemieniuk         writeOp.getIndices());
23702d34d80SAdam Siemieniuk 
23802d34d80SAdam Siemieniuk     // By default, no specific caching policy is assigned.
23902d34d80SAdam Siemieniuk     xegpu::CachePolicyAttr hint = nullptr;
24002d34d80SAdam Siemieniuk     auto storeOp =
24102d34d80SAdam Siemieniuk         rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
24202d34d80SAdam Siemieniuk                                           /*l1_hint=*/hint,
24302d34d80SAdam Siemieniuk                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
24402d34d80SAdam Siemieniuk     rewriter.replaceOp(writeOp, storeOp);
24502d34d80SAdam Siemieniuk 
24602d34d80SAdam Siemieniuk     return success();
24702d34d80SAdam Siemieniuk   }
24802d34d80SAdam Siemieniuk };
24902d34d80SAdam Siemieniuk 
2506c25604dSAdam Siemieniuk struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
2516c25604dSAdam Siemieniuk   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
2526c25604dSAdam Siemieniuk 
2536c25604dSAdam Siemieniuk   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
2546c25604dSAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
2556c25604dSAdam Siemieniuk     Location loc = loadOp.getLoc();
2566c25604dSAdam Siemieniuk 
2576c25604dSAdam Siemieniuk     VectorType vecTy = loadOp.getResult().getType();
2586c25604dSAdam Siemieniuk     if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
2596c25604dSAdam Siemieniuk       return failure();
2606c25604dSAdam Siemieniuk 
2614c597d42SAdam Siemieniuk     // Boundary check is available only for block instructions.
2624c597d42SAdam Siemieniuk     bool boundaryCheck = vecTy.getRank() > 1;
2634c597d42SAdam Siemieniuk 
2646c25604dSAdam Siemieniuk     auto descType = xegpu::TensorDescType::get(
2656c25604dSAdam Siemieniuk         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
2664c597d42SAdam Siemieniuk         boundaryCheck, xegpu::MemorySpace::Global);
2676c25604dSAdam Siemieniuk     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
2686c25604dSAdam Siemieniuk         rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
2696c25604dSAdam Siemieniuk 
2706c25604dSAdam Siemieniuk     // By default, no specific caching policy is assigned.
2716c25604dSAdam Siemieniuk     xegpu::CachePolicyAttr hint = nullptr;
2726c25604dSAdam Siemieniuk     auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
2736c25604dSAdam Siemieniuk         loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
2746c25604dSAdam Siemieniuk         /*l1_hint=*/hint,
2756c25604dSAdam Siemieniuk         /*l2_hint=*/hint, /*l3_hint=*/hint);
2766c25604dSAdam Siemieniuk     rewriter.replaceOp(loadOp, loadNdOp);
2776c25604dSAdam Siemieniuk 
2786c25604dSAdam Siemieniuk     return success();
2796c25604dSAdam Siemieniuk   }
2806c25604dSAdam Siemieniuk };
2816c25604dSAdam Siemieniuk 
2826c25604dSAdam Siemieniuk struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
2836c25604dSAdam Siemieniuk   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
2846c25604dSAdam Siemieniuk 
2856c25604dSAdam Siemieniuk   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
2866c25604dSAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
2876c25604dSAdam Siemieniuk     Location loc = storeOp.getLoc();
2886c25604dSAdam Siemieniuk 
2896c25604dSAdam Siemieniuk     TypedValue<VectorType> vector = storeOp.getValueToStore();
2906c25604dSAdam Siemieniuk     VectorType vecTy = vector.getType();
2916c25604dSAdam Siemieniuk     if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
2926c25604dSAdam Siemieniuk       return failure();
2936c25604dSAdam Siemieniuk 
2944c597d42SAdam Siemieniuk     // Boundary check is available only for block instructions.
2954c597d42SAdam Siemieniuk     bool boundaryCheck = vecTy.getRank() > 1;
2964c597d42SAdam Siemieniuk 
2974c597d42SAdam Siemieniuk     auto descType = xegpu::TensorDescType::get(
2984c597d42SAdam Siemieniuk         vecTy.getShape(), vecTy.getElementType(),
2994c597d42SAdam Siemieniuk         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
3006c25604dSAdam Siemieniuk     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
3016c25604dSAdam Siemieniuk         rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
3026c25604dSAdam Siemieniuk 
3036c25604dSAdam Siemieniuk     // By default, no specific caching policy is assigned.
3046c25604dSAdam Siemieniuk     xegpu::CachePolicyAttr hint = nullptr;
3056c25604dSAdam Siemieniuk     auto storeNdOp =
3066c25604dSAdam Siemieniuk         rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
3076c25604dSAdam Siemieniuk                                           /*l1_hint=*/hint,
3086c25604dSAdam Siemieniuk                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
3096c25604dSAdam Siemieniuk     rewriter.replaceOp(storeOp, storeNdOp);
3106c25604dSAdam Siemieniuk 
3116c25604dSAdam Siemieniuk     return success();
3126c25604dSAdam Siemieniuk   }
3136c25604dSAdam Siemieniuk };
3146c25604dSAdam Siemieniuk 
31502d34d80SAdam Siemieniuk struct ConvertVectorToXeGPUPass
31602d34d80SAdam Siemieniuk     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
31702d34d80SAdam Siemieniuk   void runOnOperation() override {
31802d34d80SAdam Siemieniuk     RewritePatternSet patterns(&getContext());
31902d34d80SAdam Siemieniuk     populateVectorToXeGPUConversionPatterns(patterns);
32009dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
32102d34d80SAdam Siemieniuk       return signalPassFailure();
32202d34d80SAdam Siemieniuk   }
32302d34d80SAdam Siemieniuk };
32402d34d80SAdam Siemieniuk 
32502d34d80SAdam Siemieniuk } // namespace
32602d34d80SAdam Siemieniuk 
32702d34d80SAdam Siemieniuk void mlir::populateVectorToXeGPUConversionPatterns(
32802d34d80SAdam Siemieniuk     RewritePatternSet &patterns) {
3296c25604dSAdam Siemieniuk   patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
3306c25604dSAdam Siemieniuk                StoreLowering>(patterns.getContext());
33102d34d80SAdam Siemieniuk }
33202d34d80SAdam Siemieniuk 
// Factory for the standalone vector-to-XeGPU conversion pass.
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}
336