//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if the value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
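  // The innermost stride must be statically known to be one; outer strides
  // may still be dynamic and are handled when the tensor descriptor is
  // created.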
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
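    // Each dimension's stride is accumulated as the product of all inner
    // dimension sizes (i.e., assuming a packed row-major layout); only the
    // strides that are dynamic in the memref type are materialized as SSA
    // values.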
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}