1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Support/MathExtras.h" 20 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 using namespace mlir::vector; 25 26 // Helper to reduce vector type by one rank at front. 27 static VectorType reducedVectorTypeFront(VectorType tp) { 28 assert((tp.getRank() > 1) && "unlowerable vector type"); 29 unsigned numScalableDims = tp.getNumScalableDims(); 30 if (tp.getShape().size() == numScalableDims) 31 --numScalableDims; 32 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 33 numScalableDims); 34 } 35 36 // Helper to reduce vector type by *all* but one rank at back. 37 static VectorType reducedVectorTypeBack(VectorType tp) { 38 assert((tp.getRank() > 1) && "unlowerable vector type"); 39 unsigned numScalableDims = tp.getNumScalableDims(); 40 if (numScalableDims > 0) 41 --numScalableDims; 42 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 43 numScalableDims); 44 } 45 46 // Helper that picks the proper sequence for inserting. 
// For rank 1 this emits an llvm.insertelement (with the position materialized
// as an index-typed constant); for higher ranks it emits an llvm.insertvalue
// into the aggregate.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting: llvm.extractelement
// for rank <= 1 (0-D and 1-D vectors), llvm.extractvalue for aggregates.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that returns data layout alignment of a memref, i.e. the preferred
// alignment of the converted element type. Fails if the element type cannot
// be converted to an LLVM type.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Add an index vector component to a base pointer. This almost always succeeds
// unless the last stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  // A GEP with a vector of indices produces a vector of pointers.
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The intrinsic takes the flattened 1-D operands plus the matrix
    // dimensions as attributes.
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
/// Replaces a vector.load with a plain llvm.load through the vector pointer.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

/// Replaces a vector.maskedload with the llvm.intr.masked.load intrinsic.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

/// Replaces a vector.store with a plain llvm.store through the vector pointer.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

/// Replaces a vector.maskedstore with the llvm.intr.masked.store intrinsic.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: compute the strided element pointer, then bitcast it
    // to a pointer-to-vector so the LLVM load/store moves the whole vector.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    // Dispatch to the op-specific overload above.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: base element pointer plus the per-lane index vector
    // yields a vector of pointers.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: base element pointer plus the per-lane index vector
    // yields a vector of pointers.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.valueToStore(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  // `reassociateFPRed` is forwarded to the LLVM fadd/fmul reduction
  // intrinsics and permits reassociation of the reduction order.
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "minui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == "minsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == "maxui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == "maxsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "minf")
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == "maxf")
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  // Whether LLVM is allowed to reassociate the fadd/fmul reduction order.
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of identically
/// typed operands maps directly onto llvm.shufflevector; all other cases are
/// unrolled into per-lane extract/insert pairs.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      // Mask positions >= v1Dim select from the second operand.
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // A 0-D vector has no position operand; extract lane 0 of the converted
    // 1-element LLVM vector.
    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.vector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.extract.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.vector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      // Use all but the last position to pull out the innermost 1-D vector,
      // then index into that vector with the final position below.
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Higher ranks are first unrolled by VectorFMAOpNDRewritePattern.
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // A 0-D vector has no position operand; insert into lane 0 of the
    // converted 1-element LLVM vector.
    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.insert.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.source());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      // Pull out the innermost 1-D vector that will receive the element.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
741 /// 742 /// Example: 743 /// ``` 744 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 745 /// ``` 746 /// is rewritten into: 747 /// ``` 748 /// %r = splat %f0: vector<2x4xf32> 749 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 750 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 751 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 752 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 753 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 754 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 755 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 756 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 757 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 758 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 759 /// // %r3 holds the final value. 760 /// ``` 761 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 762 public: 763 using OpRewritePattern<FMAOp>::OpRewritePattern; 764 765 void initialize() { 766 // This pattern recursively unpacks one dimension at a time. The recursion 767 // bounded as the rank is strictly decreasing. 
768 setHasBoundedRewriteRecursion(); 769 } 770 771 LogicalResult matchAndRewrite(FMAOp op, 772 PatternRewriter &rewriter) const override { 773 auto vType = op.getVectorType(); 774 if (vType.getRank() < 2) 775 return failure(); 776 777 auto loc = op.getLoc(); 778 auto elemType = vType.getElementType(); 779 Value zero = rewriter.create<arith::ConstantOp>( 780 loc, elemType, rewriter.getZeroAttr(elemType)); 781 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 782 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 783 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 784 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 785 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 786 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 787 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 788 } 789 rewriter.replaceOp(op, desc); 790 return success(); 791 } 792 }; 793 794 /// Returns the strides if the memory underlying `memRefType` has a contiguous 795 /// static layout. 796 static llvm::Optional<SmallVector<int64_t, 4>> 797 computeContiguousStrides(MemRefType memRefType) { 798 int64_t offset; 799 SmallVector<int64_t, 4> strides; 800 if (failed(getStridesAndOffset(memRefType, strides, offset))) 801 return None; 802 if (!strides.empty() && strides.back() != 1) 803 return None; 804 // If no layout or identity layout, this is contiguous by definition. 805 if (memRefType.getLayout().isIdentity()) 806 return strides; 807 808 // Otherwise, we must determine contiguity form shapes. This can only ever 809 // work in static cases because MemRefType is underspecified to represent 810 // contiguous dynamic shapes in other ways than with just empty/identity 811 // layout. 
812 auto sizes = memRefType.getShape(); 813 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 814 if (ShapedType::isDynamic(sizes[index + 1]) || 815 ShapedType::isDynamicStrideOrOffset(strides[index]) || 816 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 817 return None; 818 if (strides[index] != strides[index + 1] * sizes[index + 1]) 819 return None; 820 } 821 return strides; 822 } 823 824 class VectorTypeCastOpConversion 825 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 826 public: 827 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 828 829 LogicalResult 830 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 831 ConversionPatternRewriter &rewriter) const override { 832 auto loc = castOp->getLoc(); 833 MemRefType sourceMemRefType = 834 castOp.getOperand().getType().cast<MemRefType>(); 835 MemRefType targetMemRefType = castOp.getType(); 836 837 // Only static shape casts supported atm. 838 if (!sourceMemRefType.hasStaticShape() || 839 !targetMemRefType.hasStaticShape()) 840 return failure(); 841 842 auto llvmSourceDescriptorTy = 843 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 844 if (!llvmSourceDescriptorTy) 845 return failure(); 846 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 847 848 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 849 .dyn_cast_or_null<LLVM::LLVMStructType>(); 850 if (!llvmTargetDescriptorTy) 851 return failure(); 852 853 // Only contiguous source buffers supported atm. 854 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 855 if (!sourceStrides) 856 return failure(); 857 auto targetStrides = computeContiguousStrides(targetMemRefType); 858 if (!targetStrides) 859 return failure(); 860 // Only support static strides for now, regardless of contiguity. 
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr: reuse the source allocation, reinterpreted as the
    // target element type.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref. Both are static here, so
    // they are materialized as LLVM constants.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // Bail out if the printed type itself is not convertible to LLVM.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    // Select the runtime print function (and any required integer
    // extension) based on the scalar element type. f32/f64/index map
    // directly; integers up to 64 bits are extended to 64 bits first.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      // No runtime support for this element type.
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Integer widening applied to a scalar before handing it to the
  // 64-bit runtime print functions.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  // Recursively emits print calls for `value` of the given `type`:
  // a scalar is (optionally extended and) printed directly; a vector is
  // printed as "( elem, elem, ... )" by recursing one rank at a time.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      // Scalar base case: widen if requested, then call the printer.
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      // Innermost dimension: extract and print each scalar element.
      // A 0-D vector is printed as a single-element 1-D vector.
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Higher-rank case: peel off the leading dimension and recurse.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
1067 void mlir::populateVectorToLLVMConversionPatterns( 1068 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1069 bool reassociateFPReductions) { 1070 MLIRContext *ctx = converter.getDialect()->getContext(); 1071 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1072 populateVectorInsertExtractStridedSliceTransforms(patterns); 1073 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1074 patterns 1075 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1076 VectorExtractElementOpConversion, VectorExtractOpConversion, 1077 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1078 VectorInsertOpConversion, VectorPrintOpConversion, 1079 VectorTypeCastOpConversion, VectorScaleOpConversion, 1080 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1081 VectorLoadStoreConversion<vector::MaskedLoadOp, 1082 vector::MaskedLoadOpAdaptor>, 1083 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1084 VectorLoadStoreConversion<vector::MaskedStoreOp, 1085 vector::MaskedStoreOpAdaptor>, 1086 VectorGatherOpConversion, VectorScatterOpConversion, 1087 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( 1088 converter); 1089 // Transfer ops with rank > 1 are handled by VectorToSCF. 1090 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1091 } 1092 1093 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1094 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1095 patterns.add<VectorMatmulOpConversion>(converter); 1096 patterns.add<VectorFlatTransposeOpConversion>(converter); 1097 } 1098