1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Support/MathExtras.h" 20 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 using namespace mlir::vector; 25 26 // Helper to reduce vector type by one rank at front. 27 static VectorType reducedVectorTypeFront(VectorType tp) { 28 assert((tp.getRank() > 1) && "unlowerable vector type"); 29 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 30 } 31 32 // Helper to reduce vector type by *all* but one rank at back. 33 static VectorType reducedVectorTypeBack(VectorType tp) { 34 assert((tp.getRank() > 1) && "unlowerable vector type"); 35 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 36 } 37 38 // Helper that picks the proper sequence for inserting. 
39 static Value insertOne(ConversionPatternRewriter &rewriter, 40 LLVMTypeConverter &typeConverter, Location loc, 41 Value val1, Value val2, Type llvmType, int64_t rank, 42 int64_t pos) { 43 assert(rank > 0 && "0-D vector corner case should have been handled already"); 44 if (rank == 1) { 45 auto idxType = rewriter.getIndexType(); 46 auto constant = rewriter.create<LLVM::ConstantOp>( 47 loc, typeConverter.convertType(idxType), 48 rewriter.getIntegerAttr(idxType, pos)); 49 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 50 constant); 51 } 52 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 53 rewriter.getI64ArrayAttr(pos)); 54 } 55 56 // Helper that picks the proper sequence for extracting. 57 static Value extractOne(ConversionPatternRewriter &rewriter, 58 LLVMTypeConverter &typeConverter, Location loc, 59 Value val, Type llvmType, int64_t rank, int64_t pos) { 60 if (rank <= 1) { 61 auto idxType = rewriter.getIndexType(); 62 auto constant = rewriter.create<LLVM::ConstantOp>( 63 loc, typeConverter.convertType(idxType), 64 rewriter.getIntegerAttr(idxType, pos)); 65 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 66 constant); 67 } 68 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 69 rewriter.getI64ArrayAttr(pos)); 70 } 71 72 // Helper that returns data layout alignment of a memref. 73 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 74 MemRefType memrefType, unsigned &align) { 75 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 76 if (!elementTy) 77 return failure(); 78 79 // TODO: this should use the MLIR data layout when it becomes available and 80 // stop depending on translation. 
81 llvm::LLVMContext llvmContext; 82 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 83 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 84 return success(); 85 } 86 87 // Return the minimal alignment value that satisfies all the AssumeAlignment 88 // uses of `value`. If no such uses exist, return 1. 89 static unsigned getAssumedAlignment(Value value) { 90 unsigned align = 1; 91 for (auto &u : value.getUses()) { 92 Operation *owner = u.getOwner(); 93 if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner)) 94 align = mlir::lcm(align, op.alignment()); 95 } 96 return align; 97 } 98 99 // Helper that returns data layout alignment of a memref associated with a 100 // load, store, scatter, or gather op, including additional information from 101 // assume_alignment calls on the source of the transfer 102 template <class OpAdaptor> 103 LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter, 104 OpAdaptor op, unsigned &align) { 105 if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align))) 106 return failure(); 107 align = std::max(align, getAssumedAlignment(op.base())); 108 return success(); 109 } 110 111 // Add an index vector component to a base pointer. This almost always succeeds 112 // unless the last stride is non-unit or the memory space is not zero. 
113 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 114 Location loc, Value memref, Value base, 115 Value index, MemRefType memRefType, 116 VectorType vType, Value &ptrs) { 117 int64_t offset; 118 SmallVector<int64_t, 4> strides; 119 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 120 if (failed(successStrides) || strides.back() != 1 || 121 memRefType.getMemorySpaceAsInt() != 0) 122 return failure(); 123 auto pType = MemRefDescriptor(memref).getElementPtrType(); 124 auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); 125 ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index); 126 return success(); 127 } 128 129 // Casts a strided element pointer to a vector pointer. The vector pointer 130 // will be in the same address space as the incoming memref type. 131 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 132 Value ptr, MemRefType memRefType, Type vt) { 133 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); 134 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 135 } 136 137 namespace { 138 139 /// Conversion pattern for a vector.bitcast. 140 class VectorBitCastOpConversion 141 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 142 public: 143 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 144 145 LogicalResult 146 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 147 ConversionPatternRewriter &rewriter) const override { 148 // Only 1-D vectors can be lowered to LLVM. 149 VectorType resultTy = bitCastOp.getType(); 150 if (resultTy.getRank() != 1) 151 return failure(); 152 Type newResultTy = typeConverter->convertType(resultTy); 153 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 154 adaptor.getOperands()[0]); 155 return success(); 156 } 157 }; 158 159 /// Conversion pattern for a vector.matrix_multiply. 
160 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 161 class VectorMatmulOpConversion 162 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 163 public: 164 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 165 166 LogicalResult 167 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const override { 169 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 170 matmulOp, typeConverter->convertType(matmulOp.res().getType()), 171 adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), 172 matmulOp.lhs_columns(), matmulOp.rhs_columns()); 173 return success(); 174 } 175 }; 176 177 /// Conversion pattern for a vector.flat_transpose. 178 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 179 class VectorFlatTransposeOpConversion 180 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 181 public: 182 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 183 184 LogicalResult 185 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 186 ConversionPatternRewriter &rewriter) const override { 187 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 188 transOp, typeConverter->convertType(transOp.res().getType()), 189 adaptor.matrix(), transOp.rows(), transOp.columns()); 190 return success(); 191 } 192 }; 193 194 /// Overloaded utility that replaces a vector.load, vector.store, 195 /// vector.maskedload and vector.maskedstore with their respective LLVM 196 /// couterparts. 
// Overload for vector.load: emits a plain llvm.load through `ptr`.
// `adaptor` and `vectorTy` are unused here; they exist only so that all
// overloads share one signature for the templated pattern below.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

// Overload for vector.maskedload: emits llvm.intr.masked.load, forwarding the
// mask and pass-through operands.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

// Overload for vector.store: emits a plain llvm.store of the converted value.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

// Overload for vector.maskedstore: emits llvm.intr.masked.store with the mask.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
/// NOTE(review): the `LoadOrStoreOpAdaptor` template parameter is not used in
/// this body; the adaptor type comes from `LoadOrStoreOp::Adaptor`.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    // (The check rejects rank >= 2 only, so 0-D vectors also pass.)
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
                                    align)))
      return failure();

    // Resolve address.
    // The converted vector type is cast to the builtin VectorType, and the
    // strided element pointer is bitcast to a pointer to that vector type.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    // Dispatch to the overload matching the concrete op kind.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
/// Lowers to llvm.intr.masked.gather over a vector of per-lane pointers.
/// Fails when getIndexedPtrs rejects the memref (non-unit innermost stride or
/// non-zero memory space).
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
      return failure();

    // Resolve address.
    // `ptr` addresses the scalar base element; the index vector is then
    // applied on top of it to form one pointer per lane.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
/// Mirror image of the gather lowering, producing llvm.intr.masked.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
/// Lowers to llvm.intr.masked.expandload on the strided element pointer.
/// Unlike gather/scatter, no alignment is computed or passed here.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
/// Lowers to llvm.intr.masked.compressstore; no alignment is passed.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.valueToStore(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
/// The reduction `kind` string selects an llvm.intr.vector.reduce.* op.
/// Unrecognized kinds, and element types that are neither int/index nor
/// float, make the pattern fail.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  /// `reassociateFPRed` is forwarded as the reassociation flag on the
  /// floating-point add/mul reductions.
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    // Operand 0 is the vector being reduced; an optional operand 1 (used only
    // by the FP add/mul branches below) is the accumulator.
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "minui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == "minsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == "maxui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == "maxsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "minf")
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == "maxf")
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for vector.shuffle.
/// Rank-1 shuffles whose two operands have identical types map directly to
/// llvm.shufflevector. All other cases are expanded into one extract/insert
/// pair per mask entry.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the values one at a time.
    // The element type is the LLVM array element for rank > 1 results, or the
    // vector element type for rank-1 results.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      // Mask entries below v1Dim index into v1; the rest index into v2 after
      // subtracting v1Dim.
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Conversion pattern for vector.extractelement -> llvm.extractelement.
/// For 0-D vectors the position operand is not used; lane 0 is addressed with
/// an explicit index constant instead.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.vector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.extract with a static position.
/// A vector result is taken out of the LLVM array with one extractvalue.
/// A scalar result is obtained by an optional extractvalue of the innermost
/// 1-D vector, followed by an extractelement at the last position (an i64
/// constant).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.vector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
623 /// 624 /// Example: 625 /// ``` 626 /// vector.fma %a, %a, %a : vector<8xf32> 627 /// ``` 628 /// is converted to: 629 /// ``` 630 /// llvm.intr.fmuladd %va, %va, %va: 631 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 632 /// -> !llvm."<8 x f32>"> 633 /// ``` 634 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 635 public: 636 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 637 638 LogicalResult 639 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 640 ConversionPatternRewriter &rewriter) const override { 641 VectorType vType = fmaOp.getVectorType(); 642 if (vType.getRank() != 1) 643 return failure(); 644 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 645 adaptor.rhs(), adaptor.acc()); 646 return success(); 647 } 648 }; 649 650 class VectorInsertElementOpConversion 651 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 652 public: 653 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 654 655 LogicalResult 656 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 657 ConversionPatternRewriter &rewriter) const override { 658 auto vectorType = insertEltOp.getDestVectorType(); 659 auto llvmType = typeConverter->convertType(vectorType); 660 661 // Bail if result type cannot be lowered. 
662 if (!llvmType) 663 return failure(); 664 665 if (vectorType.getRank() == 0) { 666 Location loc = insertEltOp.getLoc(); 667 auto idxType = rewriter.getIndexType(); 668 auto zero = rewriter.create<LLVM::ConstantOp>( 669 loc, typeConverter->convertType(idxType), 670 rewriter.getIntegerAttr(idxType, 0)); 671 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 672 insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero); 673 return success(); 674 } 675 676 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 677 insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 678 adaptor.position()); 679 return success(); 680 } 681 }; 682 683 class VectorInsertOpConversion 684 : public ConvertOpToLLVMPattern<vector::InsertOp> { 685 public: 686 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 687 688 LogicalResult 689 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 690 ConversionPatternRewriter &rewriter) const override { 691 auto loc = insertOp->getLoc(); 692 auto sourceType = insertOp.getSourceType(); 693 auto destVectorType = insertOp.getDestVectorType(); 694 auto llvmResultType = typeConverter->convertType(destVectorType); 695 auto positionArrayAttr = insertOp.position(); 696 697 // Bail if result type cannot be lowered. 698 if (!llvmResultType) 699 return failure(); 700 701 // Overwrite entire vector with value. Should be handled by folder, but 702 // just to be safe. 703 if (positionArrayAttr.empty()) { 704 rewriter.replaceOp(insertOp, adaptor.source()); 705 return success(); 706 } 707 708 // One-shot insertion of a vector into an array (only requires insertvalue). 709 if (sourceType.isa<VectorType>()) { 710 Value inserted = rewriter.create<LLVM::InsertValueOp>( 711 loc, llvmResultType, adaptor.dest(), adaptor.source(), 712 positionArrayAttr); 713 rewriter.replaceOp(insertOp, inserted); 714 return success(); 715 } 716 717 // Potential extraction of 1-D vector from array. 
718 auto *context = insertOp->getContext(); 719 Value extracted = adaptor.dest(); 720 auto positionAttrs = positionArrayAttr.getValue(); 721 auto position = positionAttrs.back().cast<IntegerAttr>(); 722 auto oneDVectorType = destVectorType; 723 if (positionAttrs.size() > 1) { 724 oneDVectorType = reducedVectorTypeBack(destVectorType); 725 auto nMinusOnePositionAttrs = 726 ArrayAttr::get(context, positionAttrs.drop_back()); 727 extracted = rewriter.create<LLVM::ExtractValueOp>( 728 loc, typeConverter->convertType(oneDVectorType), extracted, 729 nMinusOnePositionAttrs); 730 } 731 732 // Insertion of an element into a 1-D LLVM vector. 733 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 734 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 735 Value inserted = rewriter.create<LLVM::InsertElementOp>( 736 loc, typeConverter->convertType(oneDVectorType), extracted, 737 adaptor.source(), constant); 738 739 // Potential insertion of resulting 1-D vector into array. 740 if (positionAttrs.size() > 1) { 741 auto nMinusOnePositionAttrs = 742 ArrayAttr::get(context, positionAttrs.drop_back()); 743 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 744 adaptor.dest(), inserted, 745 nMinusOnePositionAttrs); 746 } 747 748 rewriter.replaceOp(insertOp, inserted); 749 return success(); 750 } 751 }; 752 753 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
754 /// 755 /// Example: 756 /// ``` 757 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 758 /// ``` 759 /// is rewritten into: 760 /// ``` 761 /// %r = splat %f0: vector<2x4xf32> 762 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 763 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 764 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 765 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 766 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 767 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 768 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 769 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 770 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 771 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 772 /// // %r3 holds the final value. 773 /// ``` 774 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 775 public: 776 using OpRewritePattern<FMAOp>::OpRewritePattern; 777 778 void initialize() { 779 // This pattern recursively unpacks one dimension at a time. The recursion 780 // bounded as the rank is strictly decreasing. 
781 setHasBoundedRewriteRecursion(); 782 } 783 784 LogicalResult matchAndRewrite(FMAOp op, 785 PatternRewriter &rewriter) const override { 786 auto vType = op.getVectorType(); 787 if (vType.getRank() < 2) 788 return failure(); 789 790 auto loc = op.getLoc(); 791 auto elemType = vType.getElementType(); 792 Value zero = rewriter.create<arith::ConstantOp>( 793 loc, elemType, rewriter.getZeroAttr(elemType)); 794 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 795 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 796 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 797 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 798 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 799 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 800 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 801 } 802 rewriter.replaceOp(op, desc); 803 return success(); 804 } 805 }; 806 807 /// Returns the strides if the memory underlying `memRefType` has a contiguous 808 /// static layout. 809 static llvm::Optional<SmallVector<int64_t, 4>> 810 computeContiguousStrides(MemRefType memRefType) { 811 int64_t offset; 812 SmallVector<int64_t, 4> strides; 813 if (failed(getStridesAndOffset(memRefType, strides, offset))) 814 return None; 815 if (!strides.empty() && strides.back() != 1) 816 return None; 817 // If no layout or identity layout, this is contiguous by definition. 818 if (memRefType.getLayout().isIdentity()) 819 return strides; 820 821 // Otherwise, we must determine contiguity form shapes. This can only ever 822 // work in static cases because MemRefType is underspecified to represent 823 // contiguous dynamic shapes in other ways than with just empty/identity 824 // layout. 
825 auto sizes = memRefType.getShape(); 826 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 827 if (ShapedType::isDynamic(sizes[index + 1]) || 828 ShapedType::isDynamicStrideOrOffset(strides[index]) || 829 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 830 return None; 831 if (strides[index] != strides[index + 1] * sizes[index + 1]) 832 return None; 833 } 834 return strides; 835 } 836 837 class VectorTypeCastOpConversion 838 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 839 public: 840 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 841 842 LogicalResult 843 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 844 ConversionPatternRewriter &rewriter) const override { 845 auto loc = castOp->getLoc(); 846 MemRefType sourceMemRefType = 847 castOp.getOperand().getType().cast<MemRefType>(); 848 MemRefType targetMemRefType = castOp.getType(); 849 850 // Only static shape casts supported atm. 851 if (!sourceMemRefType.hasStaticShape() || 852 !targetMemRefType.hasStaticShape()) 853 return failure(); 854 855 auto llvmSourceDescriptorTy = 856 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 857 if (!llvmSourceDescriptorTy) 858 return failure(); 859 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 860 861 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 862 .dyn_cast_or_null<LLVM::LLVMStructType>(); 863 if (!llvmTargetDescriptorTy) 864 return failure(); 865 866 // Only contiguous source buffers supported atm. 867 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 868 if (!sourceStrides) 869 return failure(); 870 auto targetStrides = computeContiguousStrides(targetMemRefType); 871 if (!targetStrides) 872 return failure(); 873 // Only support static strides for now, regardless of contiguity. 
874 if (llvm::any_of(*targetStrides, [](int64_t stride) { 875 return ShapedType::isDynamicStrideOrOffset(stride); 876 })) 877 return failure(); 878 879 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 880 881 // Create descriptor. 882 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 883 Type llvmTargetElementTy = desc.getElementPtrType(); 884 // Set allocated ptr. 885 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 886 allocated = 887 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 888 desc.setAllocatedPtr(rewriter, loc, allocated); 889 // Set aligned ptr. 890 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 891 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 892 desc.setAlignedPtr(rewriter, loc, ptr); 893 // Fill offset 0. 894 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 895 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 896 desc.setOffset(rewriter, loc, zero); 897 898 // Fill size and stride descriptors in memref. 
899 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 900 int64_t index = indexedSize.index(); 901 auto sizeAttr = 902 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 903 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 904 desc.setSize(rewriter, loc, index, size); 905 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 906 (*targetStrides)[index]); 907 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 908 desc.setStride(rewriter, loc, index, stride); 909 } 910 911 rewriter.replaceOp(castOp, {desc}); 912 return success(); 913 } 914 }; 915 916 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 917 public: 918 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 919 920 // Proof-of-concept lowering implementation that relies on a small 921 // runtime support library, which only needs to provide a few 922 // printing methods (single value for all data types, opening/closing 923 // bracket, comma, newline). The lowering fully unrolls a vector 924 // in terms of these elementary printing operations. The advantage 925 // of this approach is that the library can remain unaware of all 926 // low-level implementation details of vectors while still supporting 927 // output of any shaped and dimensioned vector. Due to full unrolling, 928 // this approach is less suited for very large vectors though. 929 // 930 // TODO: rely solely on libc in future? something else? 931 // 932 LogicalResult 933 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 934 ConversionPatternRewriter &rewriter) const override { 935 Type printType = printOp.getPrintType(); 936 937 if (typeConverter->convertType(printType) == nullptr) 938 return failure(); 939 940 // Make sure element type has runtime support. 
941 PrintConversion conversion = PrintConversion::None; 942 VectorType vectorType = printType.dyn_cast<VectorType>(); 943 Type eltType = vectorType ? vectorType.getElementType() : printType; 944 Operation *printer; 945 if (eltType.isF32()) { 946 printer = 947 LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>()); 948 } else if (eltType.isF64()) { 949 printer = 950 LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>()); 951 } else if (eltType.isIndex()) { 952 printer = 953 LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>()); 954 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 955 // Integers need a zero or sign extension on the operand 956 // (depending on the source type) as well as a signed or 957 // unsigned print method. Up to 64-bit is supported. 958 unsigned width = intTy.getWidth(); 959 if (intTy.isUnsigned()) { 960 if (width <= 64) { 961 if (width < 64) 962 conversion = PrintConversion::ZeroExt64; 963 printer = LLVM::lookupOrCreatePrintU64Fn( 964 printOp->getParentOfType<ModuleOp>()); 965 } else { 966 return failure(); 967 } 968 } else { 969 assert(intTy.isSignless() || intTy.isSigned()); 970 if (width <= 64) { 971 // Note that we *always* zero extend booleans (1-bit integers), 972 // so that true/false is printed as 1/0 rather than -1/0. 973 if (width == 1) 974 conversion = PrintConversion::ZeroExt64; 975 else if (width < 64) 976 conversion = PrintConversion::SignExt64; 977 printer = LLVM::lookupOrCreatePrintI64Fn( 978 printOp->getParentOfType<ModuleOp>()); 979 } else { 980 return failure(); 981 } 982 } 983 } else { 984 return failure(); 985 } 986 987 // Unroll vector into elementary print calls. 988 int64_t rank = vectorType ? vectorType.getRank() : 0; 989 Type type = vectorType ? 
vectorType : eltType; 990 emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank, 991 conversion); 992 emitCall(rewriter, printOp->getLoc(), 993 LLVM::lookupOrCreatePrintNewlineFn( 994 printOp->getParentOfType<ModuleOp>())); 995 rewriter.eraseOp(printOp); 996 return success(); 997 } 998 999 private: 1000 enum class PrintConversion { 1001 // clang-format off 1002 None, 1003 ZeroExt64, 1004 SignExt64 1005 // clang-format on 1006 }; 1007 1008 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1009 Value value, Type type, Operation *printer, int64_t rank, 1010 PrintConversion conversion) const { 1011 VectorType vectorType = type.dyn_cast<VectorType>(); 1012 Location loc = op->getLoc(); 1013 if (!vectorType) { 1014 assert(rank == 0 && "The scalar case expects rank == 0"); 1015 switch (conversion) { 1016 case PrintConversion::ZeroExt64: 1017 value = rewriter.create<arith::ExtUIOp>( 1018 loc, value, IntegerType::get(rewriter.getContext(), 64)); 1019 break; 1020 case PrintConversion::SignExt64: 1021 value = rewriter.create<arith::ExtSIOp>( 1022 loc, value, IntegerType::get(rewriter.getContext(), 64)); 1023 break; 1024 case PrintConversion::None: 1025 break; 1026 } 1027 emitCall(rewriter, loc, printer, value); 1028 return; 1029 } 1030 1031 emitCall(rewriter, loc, 1032 LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>())); 1033 Operation *printComma = 1034 LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>()); 1035 1036 if (rank <= 1) { 1037 auto reducedType = vectorType.getElementType(); 1038 auto llvmType = typeConverter->convertType(reducedType); 1039 int64_t dim = rank == 0 ? 
1 : vectorType.getDimSize(0); 1040 for (int64_t d = 0; d < dim; ++d) { 1041 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1042 llvmType, /*rank=*/0, /*pos=*/d); 1043 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0, 1044 conversion); 1045 if (d != dim - 1) 1046 emitCall(rewriter, loc, printComma); 1047 } 1048 emitCall( 1049 rewriter, loc, 1050 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1051 return; 1052 } 1053 1054 int64_t dim = vectorType.getDimSize(0); 1055 for (int64_t d = 0; d < dim; ++d) { 1056 auto reducedType = reducedVectorTypeFront(vectorType); 1057 auto llvmType = typeConverter->convertType(reducedType); 1058 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1059 llvmType, rank, d); 1060 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1061 conversion); 1062 if (d != dim - 1) 1063 emitCall(rewriter, loc, printComma); 1064 } 1065 emitCall(rewriter, loc, 1066 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1067 } 1068 1069 // Helper to emit a call. 1070 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1071 Operation *ref, ValueRange params = ValueRange()) { 1072 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1073 params); 1074 } 1075 }; 1076 1077 } // namespace 1078 1079 /// Populate the given list with patterns that convert from Vector to LLVM. 
1080 void mlir::populateVectorToLLVMConversionPatterns( 1081 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1082 bool reassociateFPReductions) { 1083 MLIRContext *ctx = converter.getDialect()->getContext(); 1084 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1085 populateVectorInsertExtractStridedSliceTransforms(patterns); 1086 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1087 patterns 1088 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1089 VectorExtractElementOpConversion, VectorExtractOpConversion, 1090 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1091 VectorInsertOpConversion, VectorPrintOpConversion, 1092 VectorTypeCastOpConversion, 1093 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1094 VectorLoadStoreConversion<vector::MaskedLoadOp, 1095 vector::MaskedLoadOpAdaptor>, 1096 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1097 VectorLoadStoreConversion<vector::MaskedStoreOp, 1098 vector::MaskedStoreOpAdaptor>, 1099 VectorGatherOpConversion, VectorScatterOpConversion, 1100 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( 1101 converter); 1102 // Transfer ops with rank > 1 are handled by VectorToSCF. 1103 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1104 } 1105 1106 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1107 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1108 patterns.add<VectorMatmulOpConversion>(converter); 1109 patterns.add<VectorFlatTransposeOpConversion>(converter); 1110 } 1111