102d34d80SAdam Siemieniuk //===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===// 202d34d80SAdam Siemieniuk // 302d34d80SAdam Siemieniuk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 402d34d80SAdam Siemieniuk // See https://llvm.org/LICENSE.txt for license information. 502d34d80SAdam Siemieniuk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 602d34d80SAdam Siemieniuk // 702d34d80SAdam Siemieniuk //===----------------------------------------------------------------------===// 802d34d80SAdam Siemieniuk // 902d34d80SAdam Siemieniuk // This file implements lowering of vector operations to XeGPU dialect ops. 1002d34d80SAdam Siemieniuk // 1102d34d80SAdam Siemieniuk //===----------------------------------------------------------------------===// 1202d34d80SAdam Siemieniuk 1302d34d80SAdam Siemieniuk #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" 1402d34d80SAdam Siemieniuk 1502d34d80SAdam Siemieniuk #include "mlir/Dialect/Arith/IR/Arith.h" 1602d34d80SAdam Siemieniuk #include "mlir/Dialect/MemRef/IR/MemRef.h" 1702d34d80SAdam Siemieniuk #include "mlir/Dialect/Vector/IR/VectorOps.h" 1802d34d80SAdam Siemieniuk #include "mlir/Dialect/XeGPU/IR/XeGPU.h" 1902d34d80SAdam Siemieniuk #include "mlir/Pass/Pass.h" 2002d34d80SAdam Siemieniuk #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 2102d34d80SAdam Siemieniuk #include "mlir/Transforms/Passes.h" 2202d34d80SAdam Siemieniuk #include "llvm/ADT/TypeSwitch.h" 2302d34d80SAdam Siemieniuk 2402d34d80SAdam Siemieniuk #include <algorithm> 2502d34d80SAdam Siemieniuk #include <optional> 2602d34d80SAdam Siemieniuk 2702d34d80SAdam Siemieniuk namespace mlir { 2802d34d80SAdam Siemieniuk #define GEN_PASS_DEF_CONVERTVECTORTOXEGPU 2902d34d80SAdam Siemieniuk #include "mlir/Conversion/Passes.h.inc" 3002d34d80SAdam Siemieniuk } // namespace mlir 3102d34d80SAdam Siemieniuk 3202d34d80SAdam Siemieniuk using namespace mlir; 3302d34d80SAdam Siemieniuk 3402d34d80SAdam Siemieniuk namespace { 3502d34d80SAdam Siemieniuk 366c25604dSAdam Siemieniuk // Return true if value represents a zero constant. 3702d34d80SAdam Siemieniuk static bool isZeroConstant(Value val) { 3802d34d80SAdam Siemieniuk auto constant = val.getDefiningOp<arith::ConstantOp>(); 3902d34d80SAdam Siemieniuk if (!constant) 4002d34d80SAdam Siemieniuk return false; 4102d34d80SAdam Siemieniuk 4202d34d80SAdam Siemieniuk return TypeSwitch<Attribute, bool>(constant.getValue()) 4302d34d80SAdam Siemieniuk .Case<FloatAttr>( 4402d34d80SAdam Siemieniuk [](auto floatAttr) { return floatAttr.getValue().isZero(); }) 4502d34d80SAdam Siemieniuk .Case<IntegerAttr>( 4602d34d80SAdam Siemieniuk [](auto intAttr) { return intAttr.getValue().isZero(); }) 4702d34d80SAdam Siemieniuk .Default([](auto) { return false; }); 4802d34d80SAdam Siemieniuk } 4902d34d80SAdam Siemieniuk 506c25604dSAdam Siemieniuk static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter, 516c25604dSAdam Siemieniuk Operation *op, VectorType vecTy) { 526c25604dSAdam Siemieniuk // Validate only vector as the basic vector store and load ops guarantee 536c25604dSAdam Siemieniuk // XeGPU-compatible memref source. 546c25604dSAdam Siemieniuk unsigned vecRank = vecTy.getRank(); 556c25604dSAdam Siemieniuk if (!(vecRank == 1 || vecRank == 2)) 566c25604dSAdam Siemieniuk return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector"); 576c25604dSAdam Siemieniuk 586c25604dSAdam Siemieniuk return success(); 596c25604dSAdam Siemieniuk } 606c25604dSAdam Siemieniuk 6102d34d80SAdam Siemieniuk static LogicalResult transferPreconditions(PatternRewriter &rewriter, 6202d34d80SAdam Siemieniuk VectorTransferOpInterface xferOp) { 6302d34d80SAdam Siemieniuk if (xferOp.getMask()) 6402d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure(xferOp, 6502d34d80SAdam Siemieniuk "Masked transfer is not supported"); 6602d34d80SAdam Siemieniuk 6702d34d80SAdam Siemieniuk auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType()); 6802d34d80SAdam Siemieniuk if (!srcTy) 6902d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure(xferOp, "Expects memref source"); 7002d34d80SAdam Siemieniuk 716c25604dSAdam Siemieniuk // Perform common data transfer checks. 726c25604dSAdam Siemieniuk VectorType vecTy = xferOp.getVectorType(); 736c25604dSAdam Siemieniuk if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy))) 746c25604dSAdam Siemieniuk return failure(); 756c25604dSAdam Siemieniuk 766c25604dSAdam Siemieniuk // Validate further transfer op semantics. 7702d34d80SAdam Siemieniuk SmallVector<int64_t> strides; 7802d34d80SAdam Siemieniuk int64_t offset; 796aaa8f25SMatthias Springer if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1) 8002d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure( 8102d34d80SAdam Siemieniuk xferOp, "Buffer must be contiguous in the innermost dimension"); 8202d34d80SAdam Siemieniuk 836c25604dSAdam Siemieniuk unsigned vecRank = vecTy.getRank(); 844c597d42SAdam Siemieniuk if (xferOp.hasOutOfBoundsDim() && vecRank < 2) 854c597d42SAdam Siemieniuk return rewriter.notifyMatchFailure( 864c597d42SAdam Siemieniuk xferOp, "Boundary check is available only for block instructions."); 874c597d42SAdam Siemieniuk 8802d34d80SAdam Siemieniuk AffineMap map = xferOp.getPermutationMap(); 8902d34d80SAdam Siemieniuk if (!map.isProjectedPermutation(/*allowZeroInResults=*/false)) 9002d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map"); 9102d34d80SAdam Siemieniuk unsigned numInputDims = map.getNumInputs(); 9202d34d80SAdam Siemieniuk for (AffineExpr expr : map.getResults().take_back(vecRank)) { 9302d34d80SAdam Siemieniuk auto dim = dyn_cast<AffineDimExpr>(expr); 9402d34d80SAdam Siemieniuk if (dim.getPosition() < (numInputDims - vecRank)) 9502d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure( 9602d34d80SAdam Siemieniuk xferOp, "Only the innermost dimensions can be accessed"); 9702d34d80SAdam Siemieniuk } 9802d34d80SAdam Siemieniuk 9902d34d80SAdam Siemieniuk return success(); 10002d34d80SAdam Siemieniuk } 10102d34d80SAdam Siemieniuk 10202d34d80SAdam Siemieniuk static xegpu::CreateNdDescOp 10302d34d80SAdam Siemieniuk createNdDescriptor(PatternRewriter &rewriter, Location loc, 10402d34d80SAdam Siemieniuk xegpu::TensorDescType descType, TypedValue<MemRefType> src, 10502d34d80SAdam Siemieniuk Operation::operand_range offsets) { 10602d34d80SAdam Siemieniuk MemRefType srcTy = src.getType(); 1076aaa8f25SMatthias Springer auto [strides, offset] = srcTy.getStridesAndOffset(); 10802d34d80SAdam Siemieniuk 10902d34d80SAdam Siemieniuk xegpu::CreateNdDescOp ndDesc; 11002d34d80SAdam Siemieniuk if (srcTy.hasStaticShape()) { 11102d34d80SAdam Siemieniuk ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src, 11202d34d80SAdam Siemieniuk getAsOpFoldResult(offsets)); 11302d34d80SAdam Siemieniuk } else { 11402d34d80SAdam Siemieniuk // In case of any dynamic shapes, source's shape and strides have to be 11502d34d80SAdam Siemieniuk // explicitly provided. 11602d34d80SAdam Siemieniuk SmallVector<Value> sourceDims; 11702d34d80SAdam Siemieniuk unsigned srcRank = srcTy.getRank(); 11802d34d80SAdam Siemieniuk for (unsigned i = 0; i < srcRank; ++i) 11902d34d80SAdam Siemieniuk sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i)); 12002d34d80SAdam Siemieniuk 12102d34d80SAdam Siemieniuk SmallVector<int64_t> constOffsets; 12202d34d80SAdam Siemieniuk SmallVector<Value> dynOffsets; 12302d34d80SAdam Siemieniuk for (Value offset : offsets) { 12402d34d80SAdam Siemieniuk std::optional<int64_t> staticVal = getConstantIntValue(offset); 12502d34d80SAdam Siemieniuk if (!staticVal) 12602d34d80SAdam Siemieniuk dynOffsets.push_back(offset); 127b52885bcSKazu Hirata constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic)); 12802d34d80SAdam Siemieniuk } 12902d34d80SAdam Siemieniuk 13002d34d80SAdam Siemieniuk SmallVector<Value> dynShapes; 13102d34d80SAdam Siemieniuk for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) { 13202d34d80SAdam Siemieniuk if (shape == ShapedType::kDynamic) 13302d34d80SAdam Siemieniuk dynShapes.push_back(sourceDims[idx]); 13402d34d80SAdam Siemieniuk } 13502d34d80SAdam Siemieniuk 13602d34d80SAdam Siemieniuk // Compute strides in reverse order. 13702d34d80SAdam Siemieniuk SmallVector<Value> dynStrides; 13802d34d80SAdam Siemieniuk Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1); 13902d34d80SAdam Siemieniuk // Last stride is guaranteed to be static and unit. 14002d34d80SAdam Siemieniuk for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) { 14102d34d80SAdam Siemieniuk accStride = 14202d34d80SAdam Siemieniuk rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]); 14302d34d80SAdam Siemieniuk if (strides[i] == ShapedType::kDynamic) 14402d34d80SAdam Siemieniuk dynStrides.push_back(accStride); 14502d34d80SAdam Siemieniuk } 14602d34d80SAdam Siemieniuk std::reverse(dynStrides.begin(), dynStrides.end()); 14702d34d80SAdam Siemieniuk 14802d34d80SAdam Siemieniuk ndDesc = rewriter.create<xegpu::CreateNdDescOp>( 14902d34d80SAdam Siemieniuk loc, descType, src, dynOffsets, dynShapes, dynStrides, 15002d34d80SAdam Siemieniuk DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets), 15102d34d80SAdam Siemieniuk DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()), 15202d34d80SAdam Siemieniuk DenseI64ArrayAttr::get(rewriter.getContext(), strides)); 15302d34d80SAdam Siemieniuk } 15402d34d80SAdam Siemieniuk 15502d34d80SAdam Siemieniuk return ndDesc; 15602d34d80SAdam Siemieniuk } 15702d34d80SAdam Siemieniuk 15802d34d80SAdam Siemieniuk struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { 15902d34d80SAdam Siemieniuk using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 16002d34d80SAdam Siemieniuk 16102d34d80SAdam Siemieniuk LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 16202d34d80SAdam Siemieniuk PatternRewriter &rewriter) const override { 16302d34d80SAdam Siemieniuk Location loc = readOp.getLoc(); 16402d34d80SAdam Siemieniuk 16502d34d80SAdam Siemieniuk if (failed(transferPreconditions(rewriter, readOp))) 16602d34d80SAdam Siemieniuk return failure(); 16702d34d80SAdam Siemieniuk 16802d34d80SAdam Siemieniuk bool isOutOfBounds = readOp.hasOutOfBoundsDim(); 16902d34d80SAdam Siemieniuk if (isOutOfBounds && !isZeroConstant(readOp.getPadding())) 17002d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure( 17102d34d80SAdam Siemieniuk readOp, "Unsupported non-zero padded out-of-bounds read"); 17202d34d80SAdam Siemieniuk 17302d34d80SAdam Siemieniuk AffineMap readMap = readOp.getPermutationMap(); 17402d34d80SAdam Siemieniuk bool isTransposeLoad = !readMap.isMinorIdentity(); 17502d34d80SAdam Siemieniuk 17602d34d80SAdam Siemieniuk VectorType vecTy = readOp.getVectorType(); 17702d34d80SAdam Siemieniuk Type elementType = vecTy.getElementType(); 17802d34d80SAdam Siemieniuk unsigned minTransposeBitWidth = 32; 17902d34d80SAdam Siemieniuk if (isTransposeLoad && 18002d34d80SAdam Siemieniuk elementType.getIntOrFloatBitWidth() < minTransposeBitWidth) 18102d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure( 182*aa295216SJay Foad readOp, "Unsupported data type for transposition"); 18302d34d80SAdam Siemieniuk 18402d34d80SAdam Siemieniuk // If load is transposed, get the base shape for the tensor descriptor. 1859cbc1f29SHan-Chung Wang SmallVector<int64_t> descShape(vecTy.getShape()); 18602d34d80SAdam Siemieniuk if (isTransposeLoad) 18702d34d80SAdam Siemieniuk std::reverse(descShape.begin(), descShape.end()); 18802d34d80SAdam Siemieniuk auto descType = xegpu::TensorDescType::get( 1898b5e8414SChao Chen descShape, elementType, /*array_length=*/1, 1908b5e8414SChao Chen /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global); 19102d34d80SAdam Siemieniuk 19202d34d80SAdam Siemieniuk xegpu::CreateNdDescOp ndDesc = 19302d34d80SAdam Siemieniuk createNdDescriptor(rewriter, loc, descType, 19402d34d80SAdam Siemieniuk dyn_cast<TypedValue<MemRefType>>(readOp.getSource()), 19502d34d80SAdam Siemieniuk readOp.getIndices()); 19602d34d80SAdam Siemieniuk 19702d34d80SAdam Siemieniuk DenseI64ArrayAttr transposeAttr = 19802d34d80SAdam Siemieniuk !isTransposeLoad ? nullptr 19902d34d80SAdam Siemieniuk : DenseI64ArrayAttr::get(rewriter.getContext(), 20002d34d80SAdam Siemieniuk ArrayRef<int64_t>{1, 0}); 20102d34d80SAdam Siemieniuk // By default, no specific caching policy is assigned. 20202d34d80SAdam Siemieniuk xegpu::CachePolicyAttr hint = nullptr; 20302d34d80SAdam Siemieniuk auto loadOp = rewriter.create<xegpu::LoadNdOp>( 20402d34d80SAdam Siemieniuk loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr, 20502d34d80SAdam Siemieniuk /*l1_hint=*/hint, 20602d34d80SAdam Siemieniuk /*l2_hint=*/hint, /*l3_hint=*/hint); 20702d34d80SAdam Siemieniuk rewriter.replaceOp(readOp, loadOp); 20802d34d80SAdam Siemieniuk 20902d34d80SAdam Siemieniuk return success(); 21002d34d80SAdam Siemieniuk } 21102d34d80SAdam Siemieniuk }; 21202d34d80SAdam Siemieniuk 21302d34d80SAdam Siemieniuk struct TransferWriteLowering 21402d34d80SAdam Siemieniuk : public OpRewritePattern<vector::TransferWriteOp> { 21502d34d80SAdam Siemieniuk using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 21602d34d80SAdam Siemieniuk 21702d34d80SAdam Siemieniuk LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 21802d34d80SAdam Siemieniuk PatternRewriter &rewriter) const override { 21902d34d80SAdam Siemieniuk Location loc = writeOp.getLoc(); 22002d34d80SAdam Siemieniuk 22102d34d80SAdam Siemieniuk if (failed(transferPreconditions(rewriter, writeOp))) 22202d34d80SAdam Siemieniuk return failure(); 22302d34d80SAdam Siemieniuk 22402d34d80SAdam Siemieniuk AffineMap map = writeOp.getPermutationMap(); 22502d34d80SAdam Siemieniuk if (!map.isMinorIdentity()) 22602d34d80SAdam Siemieniuk return rewriter.notifyMatchFailure(writeOp, "Expects identity map"); 22702d34d80SAdam Siemieniuk 22802d34d80SAdam Siemieniuk VectorType vecTy = writeOp.getVectorType(); 229ec450b19SAdam Siemieniuk auto descType = xegpu::TensorDescType::get( 230ec450b19SAdam Siemieniuk vecTy.getShape(), vecTy.getElementType(), 231ec450b19SAdam Siemieniuk /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(), 2328b5e8414SChao Chen xegpu::MemorySpace::Global); 23302d34d80SAdam Siemieniuk xegpu::CreateNdDescOp ndDesc = createNdDescriptor( 23402d34d80SAdam Siemieniuk rewriter, loc, descType, 23502d34d80SAdam Siemieniuk dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()), 23602d34d80SAdam Siemieniuk writeOp.getIndices()); 23702d34d80SAdam Siemieniuk 23802d34d80SAdam Siemieniuk // By default, no specific caching policy is assigned. 23902d34d80SAdam Siemieniuk xegpu::CachePolicyAttr hint = nullptr; 24002d34d80SAdam Siemieniuk auto storeOp = 24102d34d80SAdam Siemieniuk rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc, 24202d34d80SAdam Siemieniuk /*l1_hint=*/hint, 24302d34d80SAdam Siemieniuk /*l2_hint=*/hint, /*l3_hint=*/hint); 24402d34d80SAdam Siemieniuk rewriter.replaceOp(writeOp, storeOp); 24502d34d80SAdam Siemieniuk 24602d34d80SAdam Siemieniuk return success(); 24702d34d80SAdam Siemieniuk } 24802d34d80SAdam Siemieniuk }; 24902d34d80SAdam Siemieniuk 2506c25604dSAdam Siemieniuk struct LoadLowering : public OpRewritePattern<vector::LoadOp> { 2516c25604dSAdam Siemieniuk using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 2526c25604dSAdam Siemieniuk 2536c25604dSAdam Siemieniuk LogicalResult matchAndRewrite(vector::LoadOp loadOp, 2546c25604dSAdam Siemieniuk PatternRewriter &rewriter) const override { 2556c25604dSAdam Siemieniuk Location loc = loadOp.getLoc(); 2566c25604dSAdam Siemieniuk 2576c25604dSAdam Siemieniuk VectorType vecTy = loadOp.getResult().getType(); 2586c25604dSAdam Siemieniuk if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy))) 2596c25604dSAdam Siemieniuk return failure(); 2606c25604dSAdam Siemieniuk 2614c597d42SAdam Siemieniuk // Boundary check is available only for block instructions. 2624c597d42SAdam Siemieniuk bool boundaryCheck = vecTy.getRank() > 1; 2634c597d42SAdam Siemieniuk 2646c25604dSAdam Siemieniuk auto descType = xegpu::TensorDescType::get( 2656c25604dSAdam Siemieniuk vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, 2664c597d42SAdam Siemieniuk boundaryCheck, xegpu::MemorySpace::Global); 2676c25604dSAdam Siemieniuk xegpu::CreateNdDescOp ndDesc = createNdDescriptor( 2686c25604dSAdam Siemieniuk rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices()); 2696c25604dSAdam Siemieniuk 2706c25604dSAdam Siemieniuk // By default, no specific caching policy is assigned. 2716c25604dSAdam Siemieniuk xegpu::CachePolicyAttr hint = nullptr; 2726c25604dSAdam Siemieniuk auto loadNdOp = rewriter.create<xegpu::LoadNdOp>( 2736c25604dSAdam Siemieniuk loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, 2746c25604dSAdam Siemieniuk /*l1_hint=*/hint, 2756c25604dSAdam Siemieniuk /*l2_hint=*/hint, /*l3_hint=*/hint); 2766c25604dSAdam Siemieniuk rewriter.replaceOp(loadOp, loadNdOp); 2776c25604dSAdam Siemieniuk 2786c25604dSAdam Siemieniuk return success(); 2796c25604dSAdam Siemieniuk } 2806c25604dSAdam Siemieniuk }; 2816c25604dSAdam Siemieniuk 2826c25604dSAdam Siemieniuk struct StoreLowering : public OpRewritePattern<vector::StoreOp> { 2836c25604dSAdam Siemieniuk using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 2846c25604dSAdam Siemieniuk 2856c25604dSAdam Siemieniuk LogicalResult matchAndRewrite(vector::StoreOp storeOp, 2866c25604dSAdam Siemieniuk PatternRewriter &rewriter) const override { 2876c25604dSAdam Siemieniuk Location loc = storeOp.getLoc(); 2886c25604dSAdam Siemieniuk 2896c25604dSAdam Siemieniuk TypedValue<VectorType> vector = storeOp.getValueToStore(); 2906c25604dSAdam Siemieniuk VectorType vecTy = vector.getType(); 2916c25604dSAdam Siemieniuk if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy))) 2926c25604dSAdam Siemieniuk return failure(); 2936c25604dSAdam Siemieniuk 2944c597d42SAdam Siemieniuk // Boundary check is available only for block instructions. 2954c597d42SAdam Siemieniuk bool boundaryCheck = vecTy.getRank() > 1; 2964c597d42SAdam Siemieniuk 2974c597d42SAdam Siemieniuk auto descType = xegpu::TensorDescType::get( 2984c597d42SAdam Siemieniuk vecTy.getShape(), vecTy.getElementType(), 2994c597d42SAdam Siemieniuk /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global); 3006c25604dSAdam Siemieniuk xegpu::CreateNdDescOp ndDesc = createNdDescriptor( 3016c25604dSAdam Siemieniuk rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices()); 3026c25604dSAdam Siemieniuk 3036c25604dSAdam Siemieniuk // By default, no specific caching policy is assigned. 3046c25604dSAdam Siemieniuk xegpu::CachePolicyAttr hint = nullptr; 3056c25604dSAdam Siemieniuk auto storeNdOp = 3066c25604dSAdam Siemieniuk rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc, 3076c25604dSAdam Siemieniuk /*l1_hint=*/hint, 3086c25604dSAdam Siemieniuk /*l2_hint=*/hint, /*l3_hint=*/hint); 3096c25604dSAdam Siemieniuk rewriter.replaceOp(storeOp, storeNdOp); 3106c25604dSAdam Siemieniuk 3116c25604dSAdam Siemieniuk return success(); 3126c25604dSAdam Siemieniuk } 3136c25604dSAdam Siemieniuk }; 3146c25604dSAdam Siemieniuk 31502d34d80SAdam Siemieniuk struct ConvertVectorToXeGPUPass 31602d34d80SAdam Siemieniuk : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> { 31702d34d80SAdam Siemieniuk void runOnOperation() override { 31802d34d80SAdam Siemieniuk RewritePatternSet patterns(&getContext()); 31902d34d80SAdam Siemieniuk populateVectorToXeGPUConversionPatterns(patterns); 32009dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 32102d34d80SAdam Siemieniuk return signalPassFailure(); 32202d34d80SAdam Siemieniuk } 32302d34d80SAdam Siemieniuk }; 32402d34d80SAdam Siemieniuk 32502d34d80SAdam Siemieniuk } // namespace 32602d34d80SAdam Siemieniuk 32702d34d80SAdam Siemieniuk void mlir::populateVectorToXeGPUConversionPatterns( 32802d34d80SAdam Siemieniuk RewritePatternSet &patterns) { 3296c25604dSAdam Siemieniuk patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering, 3306c25604dSAdam Siemieniuk StoreLowering>(patterns.getContext()); 33102d34d80SAdam Siemieniuk } 33202d34d80SAdam Siemieniuk 33302d34d80SAdam Siemieniuk std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() { 33402d34d80SAdam Siemieniuk return std::make_unique<ConvertVectorToXeGPUPass>(); 33502d34d80SAdam Siemieniuk } 336