1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Support/MathExtras.h" 20 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 using namespace mlir::vector; 25 26 // Helper to reduce vector type by one rank at front. 27 static VectorType reducedVectorTypeFront(VectorType tp) { 28 assert((tp.getRank() > 1) && "unlowerable vector type"); 29 unsigned numScalableDims = tp.getNumScalableDims(); 30 if (tp.getShape().size() == numScalableDims) 31 --numScalableDims; 32 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 33 numScalableDims); 34 } 35 36 // Helper to reduce vector type by *all* but one rank at back. 37 static VectorType reducedVectorTypeBack(VectorType tp) { 38 assert((tp.getRank() > 1) && "unlowerable vector type"); 39 unsigned numScalableDims = tp.getNumScalableDims(); 40 if (numScalableDims > 0) 41 --numScalableDims; 42 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 43 numScalableDims); 44 } 45 46 // Helper that picks the proper sequence for inserting. 47 static Value insertOne(ConversionPatternRewriter &rewriter, 48 LLVMTypeConverter &typeConverter, Location loc, 49 Value val1, Value val2, Type llvmType, int64_t rank, 50 int64_t pos) { 51 assert(rank > 0 && "0-D vector corner case should have been handled already"); 52 if (rank == 1) { 53 auto idxType = rewriter.getIndexType(); 54 auto constant = rewriter.create<LLVM::ConstantOp>( 55 loc, typeConverter.convertType(idxType), 56 rewriter.getIntegerAttr(idxType, pos)); 57 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 58 constant); 59 } 60 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 61 rewriter.getI64ArrayAttr(pos)); 62 } 63 64 // Helper that picks the proper sequence for extracting. 65 static Value extractOne(ConversionPatternRewriter &rewriter, 66 LLVMTypeConverter &typeConverter, Location loc, 67 Value val, Type llvmType, int64_t rank, int64_t pos) { 68 if (rank <= 1) { 69 auto idxType = rewriter.getIndexType(); 70 auto constant = rewriter.create<LLVM::ConstantOp>( 71 loc, typeConverter.convertType(idxType), 72 rewriter.getIntegerAttr(idxType, pos)); 73 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 74 constant); 75 } 76 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 77 rewriter.getI64ArrayAttr(pos)); 78 } 79 80 // Helper that returns data layout alignment of a memref. 81 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 82 MemRefType memrefType, unsigned &align) { 83 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 84 if (!elementTy) 85 return failure(); 86 87 // TODO: this should use the MLIR data layout when it becomes available and 88 // stop depending on translation. 89 llvm::LLVMContext llvmContext; 90 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 91 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 92 return success(); 93 } 94 95 // Add an index vector component to a base pointer. This almost always succeeds 96 // unless the last stride is non-unit or the memory space is not zero. 97 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 98 Location loc, Value memref, Value base, 99 Value index, MemRefType memRefType, 100 VectorType vType, Value &ptrs) { 101 int64_t offset; 102 SmallVector<int64_t, 4> strides; 103 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 104 if (failed(successStrides) || strides.back() != 1 || 105 memRefType.getMemorySpaceAsInt() != 0) 106 return failure(); 107 auto pType = MemRefDescriptor(memref).getElementPtrType(); 108 auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); 109 ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index); 110 return success(); 111 } 112 113 // Casts a strided element pointer to a vector pointer. The vector pointer 114 // will be in the same address space as the incoming memref type. 115 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 116 Value ptr, MemRefType memRefType, Type vt) { 117 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); 118 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 119 } 120 121 namespace { 122 123 /// Trivial Vector to LLVM conversions 124 using VectorScaleOpConversion = 125 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 126 127 /// Conversion pattern for a vector.bitcast. 128 class VectorBitCastOpConversion 129 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 130 public: 131 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 132 133 LogicalResult 134 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 135 ConversionPatternRewriter &rewriter) const override { 136 // Only 0-D and 1-D vectors can be lowered to LLVM. 137 VectorType resultTy = bitCastOp.getResultVectorType(); 138 if (resultTy.getRank() > 1) 139 return failure(); 140 Type newResultTy = typeConverter->convertType(resultTy); 141 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 142 adaptor.getOperands()[0]); 143 return success(); 144 } 145 }; 146 147 /// Conversion pattern for a vector.matrix_multiply. 148 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 149 class VectorMatmulOpConversion 150 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 151 public: 152 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 153 154 LogicalResult 155 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 156 ConversionPatternRewriter &rewriter) const override { 157 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 158 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 159 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 160 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 161 return success(); 162 } 163 }; 164 165 /// Conversion pattern for a vector.flat_transpose. 166 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 167 class VectorFlatTransposeOpConversion 168 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 169 public: 170 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 171 172 LogicalResult 173 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 174 ConversionPatternRewriter &rewriter) const override { 175 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 176 transOp, typeConverter->convertType(transOp.getRes().getType()), 177 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 178 return success(); 179 } 180 }; 181 182 /// Overloaded utility that replaces a vector.load, vector.store, 183 /// vector.maskedload and vector.maskedstore with their respective LLVM 184 /// couterparts. 185 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 186 vector::LoadOpAdaptor adaptor, 187 VectorType vectorTy, Value ptr, unsigned align, 188 ConversionPatternRewriter &rewriter) { 189 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align); 190 } 191 192 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 193 vector::MaskedLoadOpAdaptor adaptor, 194 VectorType vectorTy, Value ptr, unsigned align, 195 ConversionPatternRewriter &rewriter) { 196 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 197 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 198 } 199 200 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 201 vector::StoreOpAdaptor adaptor, 202 VectorType vectorTy, Value ptr, unsigned align, 203 ConversionPatternRewriter &rewriter) { 204 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 205 ptr, align); 206 } 207 208 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 209 vector::MaskedStoreOpAdaptor adaptor, 210 VectorType vectorTy, Value ptr, unsigned align, 211 ConversionPatternRewriter &rewriter) { 212 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 213 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 214 } 215 216 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 217 /// vector.maskedstore. 218 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 219 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 220 public: 221 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 222 223 LogicalResult 224 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 225 typename LoadOrStoreOp::Adaptor adaptor, 226 ConversionPatternRewriter &rewriter) const override { 227 // Only 1-D vectors can be lowered to LLVM. 228 VectorType vectorTy = loadOrStoreOp.getVectorType(); 229 if (vectorTy.getRank() > 1) 230 return failure(); 231 232 auto loc = loadOrStoreOp->getLoc(); 233 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 234 235 // Resolve alignment. 236 unsigned align; 237 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 238 return failure(); 239 240 // Resolve address. 241 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) 242 .template cast<VectorType>(); 243 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 244 adaptor.getIndices(), rewriter); 245 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); 246 247 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 248 return success(); 249 } 250 }; 251 252 /// Conversion pattern for a vector.gather. 253 class VectorGatherOpConversion 254 : public ConvertOpToLLVMPattern<vector::GatherOp> { 255 public: 256 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 257 258 LogicalResult 259 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 260 ConversionPatternRewriter &rewriter) const override { 261 auto loc = gather->getLoc(); 262 MemRefType memRefType = gather.getMemRefType(); 263 264 // Resolve alignment. 265 unsigned align; 266 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 267 return failure(); 268 269 // Resolve address. 270 Value ptrs; 271 VectorType vType = gather.getVectorType(); 272 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 273 adaptor.getIndices(), rewriter); 274 if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, 275 adaptor.getIndexVec(), memRefType, vType, ptrs))) 276 return failure(); 277 278 // Replace with the gather intrinsic. 279 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 280 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 281 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 282 return success(); 283 } 284 }; 285 286 /// Conversion pattern for a vector.scatter. 287 class VectorScatterOpConversion 288 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 289 public: 290 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 291 292 LogicalResult 293 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 294 ConversionPatternRewriter &rewriter) const override { 295 auto loc = scatter->getLoc(); 296 MemRefType memRefType = scatter.getMemRefType(); 297 298 // Resolve alignment. 299 unsigned align; 300 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 301 return failure(); 302 303 // Resolve address. 304 Value ptrs; 305 VectorType vType = scatter.getVectorType(); 306 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 307 adaptor.getIndices(), rewriter); 308 if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, 309 adaptor.getIndexVec(), memRefType, vType, ptrs))) 310 return failure(); 311 312 // Replace with the scatter intrinsic. 313 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 314 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 315 rewriter.getI32IntegerAttr(align)); 316 return success(); 317 } 318 }; 319 320 /// Conversion pattern for a vector.expandload. 321 class VectorExpandLoadOpConversion 322 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 323 public: 324 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 325 326 LogicalResult 327 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 328 ConversionPatternRewriter &rewriter) const override { 329 auto loc = expand->getLoc(); 330 MemRefType memRefType = expand.getMemRefType(); 331 332 // Resolve address. 333 auto vtype = typeConverter->convertType(expand.getVectorType()); 334 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 335 adaptor.getIndices(), rewriter); 336 337 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 338 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 339 return success(); 340 } 341 }; 342 343 /// Conversion pattern for a vector.compressstore. 344 class VectorCompressStoreOpConversion 345 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 346 public: 347 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 348 349 LogicalResult 350 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 351 ConversionPatternRewriter &rewriter) const override { 352 auto loc = compress->getLoc(); 353 MemRefType memRefType = compress.getMemRefType(); 354 355 // Resolve address. 356 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 357 adaptor.getIndices(), rewriter); 358 359 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 360 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 361 return success(); 362 } 363 }; 364 365 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 366 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 367 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 368 /// non-null. 369 template <class VectorOp, class ScalarOp> 370 static Value createIntegerReductionArithmeticOpLowering( 371 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 372 Value vectorOperand, Value accumulator) { 373 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 374 if (accumulator) 375 result = rewriter.create<ScalarOp>(loc, accumulator, result); 376 return result; 377 } 378 379 /// Helper method to lower a `vector.reduction` operation that performs 380 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 381 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 382 /// the accumulator value if non-null. 383 template <class VectorOp> 384 static Value createIntegerReductionComparisonOpLowering( 385 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 386 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 387 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 388 if (accumulator) { 389 Value cmp = 390 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 391 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 392 } 393 return result; 394 } 395 396 /// Conversion pattern for all vector reductions. 397 class VectorReductionOpConversion 398 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 399 public: 400 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 401 bool reassociateFPRed) 402 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 403 reassociateFPReductions(reassociateFPRed) {} 404 405 LogicalResult 406 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 407 ConversionPatternRewriter &rewriter) const override { 408 auto kind = reductionOp.getKind(); 409 Type eltType = reductionOp.getDest().getType(); 410 Type llvmType = typeConverter->convertType(eltType); 411 Value operand = adaptor.getVector(); 412 Value acc = adaptor.getAcc(); 413 Location loc = reductionOp.getLoc(); 414 if (eltType.isIntOrIndex()) { 415 // Integer reductions: add/mul/min/max/and/or/xor. 416 Value result; 417 switch (kind) { 418 case vector::CombiningKind::ADD: 419 result = 420 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 421 LLVM::AddOp>( 422 rewriter, loc, llvmType, operand, acc); 423 break; 424 case vector::CombiningKind::MUL: 425 result = 426 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 427 LLVM::MulOp>( 428 rewriter, loc, llvmType, operand, acc); 429 break; 430 case vector::CombiningKind::MINUI: 431 result = createIntegerReductionComparisonOpLowering< 432 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 433 LLVM::ICmpPredicate::ule); 434 break; 435 case vector::CombiningKind::MINSI: 436 result = createIntegerReductionComparisonOpLowering< 437 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 438 LLVM::ICmpPredicate::sle); 439 break; 440 case vector::CombiningKind::MAXUI: 441 result = createIntegerReductionComparisonOpLowering< 442 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 443 LLVM::ICmpPredicate::uge); 444 break; 445 case vector::CombiningKind::MAXSI: 446 result = createIntegerReductionComparisonOpLowering< 447 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 448 LLVM::ICmpPredicate::sge); 449 break; 450 case vector::CombiningKind::AND: 451 result = 452 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 453 LLVM::AndOp>( 454 rewriter, loc, llvmType, operand, acc); 455 break; 456 case vector::CombiningKind::OR: 457 result = 458 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 459 LLVM::OrOp>( 460 rewriter, loc, llvmType, operand, acc); 461 break; 462 case vector::CombiningKind::XOR: 463 result = 464 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 465 LLVM::XOrOp>( 466 rewriter, loc, llvmType, operand, acc); 467 break; 468 default: 469 return failure(); 470 } 471 rewriter.replaceOp(reductionOp, result); 472 473 return success(); 474 } 475 476 if (!eltType.isa<FloatType>()) 477 return failure(); 478 479 // Floating-point reductions: add/mul/min/max 480 if (kind == vector::CombiningKind::ADD) { 481 // Optional accumulator (or zero). 482 Value acc = adaptor.getOperands().size() > 1 483 ? adaptor.getOperands()[1] 484 : rewriter.create<LLVM::ConstantOp>( 485 reductionOp->getLoc(), llvmType, 486 rewriter.getZeroAttr(eltType)); 487 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 488 reductionOp, llvmType, acc, operand, 489 rewriter.getBoolAttr(reassociateFPReductions)); 490 } else if (kind == vector::CombiningKind::MUL) { 491 // Optional accumulator (or one). 492 Value acc = adaptor.getOperands().size() > 1 493 ? adaptor.getOperands()[1] 494 : rewriter.create<LLVM::ConstantOp>( 495 reductionOp->getLoc(), llvmType, 496 rewriter.getFloatAttr(eltType, 1.0)); 497 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 498 reductionOp, llvmType, acc, operand, 499 rewriter.getBoolAttr(reassociateFPReductions)); 500 } else if (kind == vector::CombiningKind::MINF) 501 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 502 // NaNs/-0.0/+0.0 in the same way. 503 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp, 504 llvmType, operand); 505 else if (kind == vector::CombiningKind::MAXF) 506 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 507 // NaNs/-0.0/+0.0 in the same way. 508 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp, 509 llvmType, operand); 510 else 511 return failure(); 512 return success(); 513 } 514 515 private: 516 const bool reassociateFPReductions; 517 }; 518 519 class VectorShuffleOpConversion 520 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 521 public: 522 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 523 524 LogicalResult 525 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 526 ConversionPatternRewriter &rewriter) const override { 527 auto loc = shuffleOp->getLoc(); 528 auto v1Type = shuffleOp.getV1VectorType(); 529 auto v2Type = shuffleOp.getV2VectorType(); 530 auto vectorType = shuffleOp.getVectorType(); 531 Type llvmType = typeConverter->convertType(vectorType); 532 auto maskArrayAttr = shuffleOp.getMask(); 533 534 // Bail if result type cannot be lowered. 535 if (!llvmType) 536 return failure(); 537 538 // Get rank and dimension sizes. 539 int64_t rank = vectorType.getRank(); 540 assert(v1Type.getRank() == rank); 541 assert(v2Type.getRank() == rank); 542 int64_t v1Dim = v1Type.getDimSize(0); 543 544 // For rank 1, where both operands have *exactly* the same vector type, 545 // there is direct shuffle support in LLVM. Use it! 546 if (rank == 1 && v1Type == v2Type) { 547 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 548 loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr); 549 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 550 return success(); 551 } 552 553 // For all other cases, insert the individual values individually. 554 Type eltType; 555 if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>()) 556 eltType = arrayType.getElementType(); 557 else 558 eltType = llvmType.cast<VectorType>().getElementType(); 559 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 560 int64_t insPos = 0; 561 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 562 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 563 Value value = adaptor.getV1(); 564 if (extPos >= v1Dim) { 565 extPos -= v1Dim; 566 value = adaptor.getV2(); 567 } 568 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 569 eltType, rank, extPos); 570 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 571 llvmType, rank, insPos++); 572 } 573 rewriter.replaceOp(shuffleOp, insert); 574 return success(); 575 } 576 }; 577 578 class VectorExtractElementOpConversion 579 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 580 public: 581 using ConvertOpToLLVMPattern< 582 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 583 584 LogicalResult 585 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 586 ConversionPatternRewriter &rewriter) const override { 587 auto vectorType = extractEltOp.getVectorType(); 588 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 589 590 // Bail if result type cannot be lowered. 591 if (!llvmType) 592 return failure(); 593 594 if (vectorType.getRank() == 0) { 595 Location loc = extractEltOp.getLoc(); 596 auto idxType = rewriter.getIndexType(); 597 auto zero = rewriter.create<LLVM::ConstantOp>( 598 loc, typeConverter->convertType(idxType), 599 rewriter.getIntegerAttr(idxType, 0)); 600 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 601 extractEltOp, llvmType, adaptor.getVector(), zero); 602 return success(); 603 } 604 605 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 606 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 607 return success(); 608 } 609 }; 610 611 class VectorExtractOpConversion 612 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 613 public: 614 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 615 616 LogicalResult 617 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 618 ConversionPatternRewriter &rewriter) const override { 619 auto loc = extractOp->getLoc(); 620 auto vectorType = extractOp.getVectorType(); 621 auto resultType = extractOp.getResult().getType(); 622 auto llvmResultType = typeConverter->convertType(resultType); 623 auto positionArrayAttr = extractOp.getPosition(); 624 625 // Bail if result type cannot be lowered. 626 if (!llvmResultType) 627 return failure(); 628 629 // Extract entire vector. Should be handled by folder, but just to be safe. 630 if (positionArrayAttr.empty()) { 631 rewriter.replaceOp(extractOp, adaptor.getVector()); 632 return success(); 633 } 634 635 // One-shot extraction of vector from array (only requires extractvalue). 636 if (resultType.isa<VectorType>()) { 637 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 638 loc, llvmResultType, adaptor.getVector(), positionArrayAttr); 639 rewriter.replaceOp(extractOp, extracted); 640 return success(); 641 } 642 643 // Potential extraction of 1-D vector from array. 644 auto *context = extractOp->getContext(); 645 Value extracted = adaptor.getVector(); 646 auto positionAttrs = positionArrayAttr.getValue(); 647 if (positionAttrs.size() > 1) { 648 auto oneDVectorType = reducedVectorTypeBack(vectorType); 649 auto nMinusOnePositionAttrs = 650 ArrayAttr::get(context, positionAttrs.drop_back()); 651 extracted = rewriter.create<LLVM::ExtractValueOp>( 652 loc, typeConverter->convertType(oneDVectorType), extracted, 653 nMinusOnePositionAttrs); 654 } 655 656 // Remaining extraction of element from 1-D LLVM vector 657 auto position = positionAttrs.back().cast<IntegerAttr>(); 658 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 659 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 660 extracted = 661 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 662 rewriter.replaceOp(extractOp, extracted); 663 664 return success(); 665 } 666 }; 667 668 /// Conversion pattern that turns a vector.fma on a 1-D vector 669 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 670 /// This does not match vectors of n >= 2 rank. 671 /// 672 /// Example: 673 /// ``` 674 /// vector.fma %a, %a, %a : vector<8xf32> 675 /// ``` 676 /// is converted to: 677 /// ``` 678 /// llvm.intr.fmuladd %va, %va, %va: 679 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 680 /// -> !llvm."<8 x f32>"> 681 /// ``` 682 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 683 public: 684 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 685 686 LogicalResult 687 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 688 ConversionPatternRewriter &rewriter) const override { 689 VectorType vType = fmaOp.getVectorType(); 690 if (vType.getRank() != 1) 691 return failure(); 692 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( 693 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); 694 return success(); 695 } 696 }; 697 698 class VectorInsertElementOpConversion 699 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 700 public: 701 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 702 703 LogicalResult 704 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, 705 ConversionPatternRewriter &rewriter) const override { 706 auto vectorType = insertEltOp.getDestVectorType(); 707 auto llvmType = typeConverter->convertType(vectorType); 708 709 // Bail if result type cannot be lowered. 710 if (!llvmType) 711 return failure(); 712 713 if (vectorType.getRank() == 0) { 714 Location loc = insertEltOp.getLoc(); 715 auto idxType = rewriter.getIndexType(); 716 auto zero = rewriter.create<LLVM::ConstantOp>( 717 loc, typeConverter->convertType(idxType), 718 rewriter.getIntegerAttr(idxType, 0)); 719 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 720 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); 721 return success(); 722 } 723 724 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 725 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), 726 adaptor.getPosition()); 727 return success(); 728 } 729 }; 730 731 class VectorInsertOpConversion 732 : public ConvertOpToLLVMPattern<vector::InsertOp> { 733 public: 734 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 735 736 LogicalResult 737 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 738 ConversionPatternRewriter &rewriter) const override { 739 auto loc = insertOp->getLoc(); 740 auto sourceType = insertOp.getSourceType(); 741 auto destVectorType = insertOp.getDestVectorType(); 742 auto llvmResultType = typeConverter->convertType(destVectorType); 743 auto positionArrayAttr = insertOp.getPosition(); 744 745 // Bail if result type cannot be lowered. 746 if (!llvmResultType) 747 return failure(); 748 749 // Overwrite entire vector with value. Should be handled by folder, but 750 // just to be safe. 751 if (positionArrayAttr.empty()) { 752 rewriter.replaceOp(insertOp, adaptor.getSource()); 753 return success(); 754 } 755 756 // One-shot insertion of a vector into an array (only requires insertvalue). 757 if (sourceType.isa<VectorType>()) { 758 Value inserted = rewriter.create<LLVM::InsertValueOp>( 759 loc, llvmResultType, adaptor.getDest(), adaptor.getSource(), 760 positionArrayAttr); 761 rewriter.replaceOp(insertOp, inserted); 762 return success(); 763 } 764 765 // Potential extraction of 1-D vector from array. 766 auto *context = insertOp->getContext(); 767 Value extracted = adaptor.getDest(); 768 auto positionAttrs = positionArrayAttr.getValue(); 769 auto position = positionAttrs.back().cast<IntegerAttr>(); 770 auto oneDVectorType = destVectorType; 771 if (positionAttrs.size() > 1) { 772 oneDVectorType = reducedVectorTypeBack(destVectorType); 773 auto nMinusOnePositionAttrs = 774 ArrayAttr::get(context, positionAttrs.drop_back()); 775 extracted = rewriter.create<LLVM::ExtractValueOp>( 776 loc, typeConverter->convertType(oneDVectorType), extracted, 777 nMinusOnePositionAttrs); 778 } 779 780 // Insertion of an element into a 1-D LLVM vector. 781 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 782 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 783 Value inserted = rewriter.create<LLVM::InsertElementOp>( 784 loc, typeConverter->convertType(oneDVectorType), extracted, 785 adaptor.getSource(), constant); 786 787 // Potential insertion of resulting 1-D vector into array. 788 if (positionAttrs.size() > 1) { 789 auto nMinusOnePositionAttrs = 790 ArrayAttr::get(context, positionAttrs.drop_back()); 791 inserted = rewriter.create<LLVM::InsertValueOp>( 792 loc, llvmResultType, adaptor.getDest(), inserted, 793 nMinusOnePositionAttrs); 794 } 795 796 rewriter.replaceOp(insertOp, inserted); 797 return success(); 798 } 799 }; 800 801 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 802 /// 803 /// Example: 804 /// ``` 805 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 806 /// ``` 807 /// is rewritten into: 808 /// ``` 809 /// %r = splat %f0: vector<2x4xf32> 810 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 811 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 812 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 813 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 814 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 815 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 816 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 817 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 818 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 819 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 820 /// // %r3 holds the final value. 821 /// ``` 822 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 823 public: 824 using OpRewritePattern<FMAOp>::OpRewritePattern; 825 826 void initialize() { 827 // This pattern recursively unpacks one dimension at a time. The recursion 828 // bounded as the rank is strictly decreasing. 829 setHasBoundedRewriteRecursion(); 830 } 831 832 LogicalResult matchAndRewrite(FMAOp op, 833 PatternRewriter &rewriter) const override { 834 auto vType = op.getVectorType(); 835 if (vType.getRank() < 2) 836 return failure(); 837 838 auto loc = op.getLoc(); 839 auto elemType = vType.getElementType(); 840 Value zero = rewriter.create<arith::ConstantOp>( 841 loc, elemType, rewriter.getZeroAttr(elemType)); 842 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 843 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 844 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 845 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 846 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 847 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 848 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 849 } 850 rewriter.replaceOp(op, desc); 851 return success(); 852 } 853 }; 854 855 /// Returns the strides if the memory underlying `memRefType` has a contiguous 856 /// static layout. 857 static llvm::Optional<SmallVector<int64_t, 4>> 858 computeContiguousStrides(MemRefType memRefType) { 859 int64_t offset; 860 SmallVector<int64_t, 4> strides; 861 if (failed(getStridesAndOffset(memRefType, strides, offset))) 862 return None; 863 if (!strides.empty() && strides.back() != 1) 864 return None; 865 // If no layout or identity layout, this is contiguous by definition. 866 if (memRefType.getLayout().isIdentity()) 867 return strides; 868 869 // Otherwise, we must determine contiguity form shapes. This can only ever 870 // work in static cases because MemRefType is underspecified to represent 871 // contiguous dynamic shapes in other ways than with just empty/identity 872 // layout. 873 auto sizes = memRefType.getShape(); 874 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 875 if (ShapedType::isDynamic(sizes[index + 1]) || 876 ShapedType::isDynamicStrideOrOffset(strides[index]) || 877 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 878 return None; 879 if (strides[index] != strides[index + 1] * sizes[index + 1]) 880 return None; 881 } 882 return strides; 883 } 884 885 class VectorTypeCastOpConversion 886 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 887 public: 888 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 889 890 LogicalResult 891 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 892 ConversionPatternRewriter &rewriter) const override { 893 auto loc = castOp->getLoc(); 894 MemRefType sourceMemRefType = 895 castOp.getOperand().getType().cast<MemRefType>(); 896 MemRefType targetMemRefType = castOp.getType(); 897 898 // Only static shape casts supported atm. 899 if (!sourceMemRefType.hasStaticShape() || 900 !targetMemRefType.hasStaticShape()) 901 return failure(); 902 903 auto llvmSourceDescriptorTy = 904 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 905 if (!llvmSourceDescriptorTy) 906 return failure(); 907 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 908 909 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 910 .dyn_cast_or_null<LLVM::LLVMStructType>(); 911 if (!llvmTargetDescriptorTy) 912 return failure(); 913 914 // Only contiguous source buffers supported atm. 915 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 916 if (!sourceStrides) 917 return failure(); 918 auto targetStrides = computeContiguousStrides(targetMemRefType); 919 if (!targetStrides) 920 return failure(); 921 // Only support static strides for now, regardless of contiguity. 922 if (llvm::any_of(*targetStrides, [](int64_t stride) { 923 return ShapedType::isDynamicStrideOrOffset(stride); 924 })) 925 return failure(); 926 927 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 928 929 // Create descriptor. 930 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 931 Type llvmTargetElementTy = desc.getElementPtrType(); 932 // Set allocated ptr. 933 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 934 allocated = 935 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 936 desc.setAllocatedPtr(rewriter, loc, allocated); 937 // Set aligned ptr. 938 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 939 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 940 desc.setAlignedPtr(rewriter, loc, ptr); 941 // Fill offset 0. 942 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 943 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 944 desc.setOffset(rewriter, loc, zero); 945 946 // Fill size and stride descriptors in memref. 947 for (const auto &indexedSize : 948 llvm::enumerate(targetMemRefType.getShape())) { 949 int64_t index = indexedSize.index(); 950 auto sizeAttr = 951 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 952 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 953 desc.setSize(rewriter, loc, index, size); 954 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 955 (*targetStrides)[index]); 956 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 957 desc.setStride(rewriter, loc, index, stride); 958 } 959 960 rewriter.replaceOp(castOp, {desc}); 961 return success(); 962 } 963 }; 964 965 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 966 /// Non-scalable versions of this operation are handled in Vector Transforms. 967 class VectorCreateMaskOpRewritePattern 968 : public OpRewritePattern<vector::CreateMaskOp> { 969 public: 970 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 971 bool enableIndexOpt) 972 : OpRewritePattern<vector::CreateMaskOp>(context), 973 force32BitVectorIndices(enableIndexOpt) {} 974 975 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 976 PatternRewriter &rewriter) const override { 977 auto dstType = op.getType(); 978 if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable()) 979 return failure(); 980 IntegerType idxType = 981 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 982 auto loc = op->getLoc(); 983 Value indices = rewriter.create<LLVM::StepVectorOp>( 984 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 985 /*isScalable=*/true)); 986 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 987 op.getOperand(0)); 988 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 989 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 990 indices, bounds); 991 rewriter.replaceOp(op, comp); 992 return success(); 993 } 994 995 private: 996 const bool force32BitVectorIndices; 997 }; 998 999 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1000 public: 1001 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1002 1003 // Proof-of-concept lowering implementation that relies on a small 1004 // runtime support library, which only needs to provide a few 1005 // printing methods (single value for all data types, opening/closing 1006 // bracket, comma, newline). The lowering fully unrolls a vector 1007 // in terms of these elementary printing operations. The advantage 1008 // of this approach is that the library can remain unaware of all 1009 // low-level implementation details of vectors while still supporting 1010 // output of any shaped and dimensioned vector. Due to full unrolling, 1011 // this approach is less suited for very large vectors though. 1012 // 1013 // TODO: rely solely on libc in future? something else? 1014 // 1015 LogicalResult 1016 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1017 ConversionPatternRewriter &rewriter) const override { 1018 Type printType = printOp.getPrintType(); 1019 1020 if (typeConverter->convertType(printType) == nullptr) 1021 return failure(); 1022 1023 // Make sure element type has runtime support. 1024 PrintConversion conversion = PrintConversion::None; 1025 VectorType vectorType = printType.dyn_cast<VectorType>(); 1026 Type eltType = vectorType ? vectorType.getElementType() : printType; 1027 Operation *printer; 1028 if (eltType.isF32()) { 1029 printer = 1030 LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>()); 1031 } else if (eltType.isF64()) { 1032 printer = 1033 LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>()); 1034 } else if (eltType.isIndex()) { 1035 printer = 1036 LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>()); 1037 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1038 // Integers need a zero or sign extension on the operand 1039 // (depending on the source type) as well as a signed or 1040 // unsigned print method. Up to 64-bit is supported. 1041 unsigned width = intTy.getWidth(); 1042 if (intTy.isUnsigned()) { 1043 if (width <= 64) { 1044 if (width < 64) 1045 conversion = PrintConversion::ZeroExt64; 1046 printer = LLVM::lookupOrCreatePrintU64Fn( 1047 printOp->getParentOfType<ModuleOp>()); 1048 } else { 1049 return failure(); 1050 } 1051 } else { 1052 assert(intTy.isSignless() || intTy.isSigned()); 1053 if (width <= 64) { 1054 // Note that we *always* zero extend booleans (1-bit integers), 1055 // so that true/false is printed as 1/0 rather than -1/0. 1056 if (width == 1) 1057 conversion = PrintConversion::ZeroExt64; 1058 else if (width < 64) 1059 conversion = PrintConversion::SignExt64; 1060 printer = LLVM::lookupOrCreatePrintI64Fn( 1061 printOp->getParentOfType<ModuleOp>()); 1062 } else { 1063 return failure(); 1064 } 1065 } 1066 } else { 1067 return failure(); 1068 } 1069 1070 // Unroll vector into elementary print calls. 1071 int64_t rank = vectorType ? vectorType.getRank() : 0; 1072 Type type = vectorType ? vectorType : eltType; 1073 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, 1074 conversion); 1075 emitCall(rewriter, printOp->getLoc(), 1076 LLVM::lookupOrCreatePrintNewlineFn( 1077 printOp->getParentOfType<ModuleOp>())); 1078 rewriter.eraseOp(printOp); 1079 return success(); 1080 } 1081 1082 private: 1083 enum class PrintConversion { 1084 // clang-format off 1085 None, 1086 ZeroExt64, 1087 SignExt64 1088 // clang-format on 1089 }; 1090 1091 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1092 Value value, Type type, Operation *printer, int64_t rank, 1093 PrintConversion conversion) const { 1094 VectorType vectorType = type.dyn_cast<VectorType>(); 1095 Location loc = op->getLoc(); 1096 if (!vectorType) { 1097 assert(rank == 0 && "The scalar case expects rank == 0"); 1098 switch (conversion) { 1099 case PrintConversion::ZeroExt64: 1100 value = rewriter.create<arith::ExtUIOp>( 1101 loc, IntegerType::get(rewriter.getContext(), 64), value); 1102 break; 1103 case PrintConversion::SignExt64: 1104 value = rewriter.create<arith::ExtSIOp>( 1105 loc, IntegerType::get(rewriter.getContext(), 64), value); 1106 break; 1107 case PrintConversion::None: 1108 break; 1109 } 1110 emitCall(rewriter, loc, printer, value); 1111 return; 1112 } 1113 1114 emitCall(rewriter, loc, 1115 LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>())); 1116 Operation *printComma = 1117 LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>()); 1118 1119 if (rank <= 1) { 1120 auto reducedType = vectorType.getElementType(); 1121 auto llvmType = typeConverter->convertType(reducedType); 1122 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0); 1123 for (int64_t d = 0; d < dim; ++d) { 1124 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1125 llvmType, /*rank=*/0, /*pos=*/d); 1126 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0, 1127 conversion); 1128 if (d != dim - 1) 1129 emitCall(rewriter, loc, printComma); 1130 } 1131 emitCall( 1132 rewriter, loc, 1133 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1134 return; 1135 } 1136 1137 int64_t dim = vectorType.getDimSize(0); 1138 for (int64_t d = 0; d < dim; ++d) { 1139 auto reducedType = reducedVectorTypeFront(vectorType); 1140 auto llvmType = typeConverter->convertType(reducedType); 1141 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1142 llvmType, rank, d); 1143 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1144 conversion); 1145 if (d != dim - 1) 1146 emitCall(rewriter, loc, printComma); 1147 } 1148 emitCall(rewriter, loc, 1149 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1150 } 1151 1152 // Helper to emit a call. 1153 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1154 Operation *ref, ValueRange params = ValueRange()) { 1155 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1156 params); 1157 } 1158 }; 1159 1160 /// The Splat operation is lowered to an insertelement + a shufflevector 1161 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 1162 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1163 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1164 1165 LogicalResult 1166 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1167 ConversionPatternRewriter &rewriter) const override { 1168 VectorType resultType = splatOp.getType().cast<VectorType>(); 1169 if (resultType.getRank() > 1) 1170 return failure(); 1171 1172 // First insert it into an undef vector so we can shuffle it. 1173 auto vectorType = typeConverter->convertType(splatOp.getType()); 1174 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1175 auto zero = rewriter.create<LLVM::ConstantOp>( 1176 splatOp.getLoc(), 1177 typeConverter->convertType(rewriter.getIntegerType(32)), 1178 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1179 1180 // For 0-d vector, we simply do `insertelement`. 1181 if (resultType.getRank() == 0) { 1182 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1183 splatOp, vectorType, undef, adaptor.getInput(), zero); 1184 return success(); 1185 } 1186 1187 // For 1-d vector, we additionally do a `vectorshuffle`. 1188 auto v = rewriter.create<LLVM::InsertElementOp>( 1189 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1190 1191 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); 1192 SmallVector<int32_t, 4> zeroValues(width, 0); 1193 1194 // Shuffle the value across the desired number of elements. 1195 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); 1196 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1197 zeroAttrs); 1198 return success(); 1199 } 1200 }; 1201 1202 /// The Splat operation is lowered to an insertelement + a shufflevector 1203 /// operation. Splat to only 2+-d vector result types are lowered by the 1204 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 1205 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1206 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1207 1208 LogicalResult 1209 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1210 ConversionPatternRewriter &rewriter) const override { 1211 VectorType resultType = splatOp.getType(); 1212 if (resultType.getRank() <= 1) 1213 return failure(); 1214 1215 // First insert it into an undef vector so we can shuffle it. 1216 auto loc = splatOp.getLoc(); 1217 auto vectorTypeInfo = 1218 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1219 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1220 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1221 if (!llvmNDVectorTy || !llvm1DVectorTy) 1222 return failure(); 1223 1224 // Construct returned value. 1225 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1226 1227 // Construct a 1-D vector with the splatted value that we insert in all the 1228 // places within the returned descriptor. 1229 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1230 auto zero = rewriter.create<LLVM::ConstantOp>( 1231 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1232 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1233 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1234 adaptor.getInput(), zero); 1235 1236 // Shuffle the value across the desired number of elements. 1237 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1238 SmallVector<int32_t, 4> zeroValues(width, 0); 1239 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); 1240 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs); 1241 1242 // Iterate of linear index, convert to coords space and insert splatted 1-D 1243 // vector in each position. 1244 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { 1245 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v, 1246 position); 1247 }); 1248 rewriter.replaceOp(splatOp, desc); 1249 return success(); 1250 } 1251 }; 1252 1253 } // namespace 1254 1255 /// Populate the given list with patterns that convert from Vector to LLVM. 1256 void mlir::populateVectorToLLVMConversionPatterns( 1257 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1258 bool reassociateFPReductions, bool force32BitVectorIndices) { 1259 MLIRContext *ctx = converter.getDialect()->getContext(); 1260 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1261 populateVectorInsertExtractStridedSliceTransforms(patterns); 1262 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1263 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1264 patterns 1265 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1266 VectorExtractElementOpConversion, VectorExtractOpConversion, 1267 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1268 VectorInsertOpConversion, VectorPrintOpConversion, 1269 VectorTypeCastOpConversion, VectorScaleOpConversion, 1270 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1271 VectorLoadStoreConversion<vector::MaskedLoadOp, 1272 vector::MaskedLoadOpAdaptor>, 1273 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1274 VectorLoadStoreConversion<vector::MaskedStoreOp, 1275 vector::MaskedStoreOpAdaptor>, 1276 VectorGatherOpConversion, VectorScatterOpConversion, 1277 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1278 VectorSplatOpLowering, VectorSplatNdOpLowering>(converter); 1279 // Transfer ops with rank > 1 are handled by VectorToSCF. 1280 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1281 } 1282 1283 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1284 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1285 patterns.add<VectorMatmulOpConversion>(converter); 1286 patterns.add<VectorFlatTransposeOpConversion>(converter); 1287 } 1288