//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if the value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector.store and vector.load
  // ops guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

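// Creates an XeGPU block descriptor for `src` at the given `offsets`.
// Illustrative sketch only (placeholder names, not verbatim IR): for a
// statically shaped source the descriptor is built directly from the offsets,
// e.g.
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// For dynamically shaped sources, the shape and strides are materialized
// explicitly through memref.dim and arith ops and passed to the descriptor.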
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // The last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

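// Lowers vector.transfer_read into xegpu.create_nd_tdesc + xegpu.load_nd.
// Illustrative sketch only (placeholder names, not verbatim IR):
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<8x16xf32>, vector<8x16xf32>
// becomes roughly
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// Transposed reads additionally set the load's transpose attribute.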
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape{vecTy.getShape()};
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

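// Lowers vector.transfer_write into xegpu.create_nd_tdesc + xegpu.store_nd.
// Illustrative sketch only (placeholder names, not verbatim IR):
//   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
//       : vector<8x16xf32>, memref<8x16xf32>
// becomes roughly
//   %desc = xegpu.create_nd_tdesc %dst[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %v, %desc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>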
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        /*boundary_check=*/true, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

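// Lowers vector.store into xegpu.create_nd_tdesc + xegpu.store_nd, mirroring
// LoadLowering above. Illustrative sketch only (placeholder names, not
// verbatim IR):
//   vector.store %v, %base[%i, %j] : memref<8x16xf32>, vector<8x16xf32>
// becomes roughly
//   %desc = xegpu.create_nd_tdesc %base[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %v, %desc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>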
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    auto descType =
        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
                                   /*array_length=*/1, /*boundary_check=*/true,
                                   xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}