1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Support/MathExtras.h" 21 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 24 using namespace mlir; 25 using namespace mlir::vector; 26 27 // Helper to reduce vector type by one rank at front. 28 static VectorType reducedVectorTypeFront(VectorType tp) { 29 assert((tp.getRank() > 1) && "unlowerable vector type"); 30 unsigned numScalableDims = tp.getNumScalableDims(); 31 if (tp.getShape().size() == numScalableDims) 32 --numScalableDims; 33 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 34 numScalableDims); 35 } 36 37 // Helper to reduce vector type by *all* but one rank at back. 38 static VectorType reducedVectorTypeBack(VectorType tp) { 39 assert((tp.getRank() > 1) && "unlowerable vector type"); 40 unsigned numScalableDims = tp.getNumScalableDims(); 41 if (numScalableDims > 0) 42 --numScalableDims; 43 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 44 numScalableDims); 45 } 46 47 // Helper that picks the proper sequence for inserting. 
48 static Value insertOne(ConversionPatternRewriter &rewriter, 49 LLVMTypeConverter &typeConverter, Location loc, 50 Value val1, Value val2, Type llvmType, int64_t rank, 51 int64_t pos) { 52 assert(rank > 0 && "0-D vector corner case should have been handled already"); 53 if (rank == 1) { 54 auto idxType = rewriter.getIndexType(); 55 auto constant = rewriter.create<LLVM::ConstantOp>( 56 loc, typeConverter.convertType(idxType), 57 rewriter.getIntegerAttr(idxType, pos)); 58 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 59 constant); 60 } 61 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 62 rewriter.getI64ArrayAttr(pos)); 63 } 64 65 // Helper that picks the proper sequence for extracting. 66 static Value extractOne(ConversionPatternRewriter &rewriter, 67 LLVMTypeConverter &typeConverter, Location loc, 68 Value val, Type llvmType, int64_t rank, int64_t pos) { 69 if (rank <= 1) { 70 auto idxType = rewriter.getIndexType(); 71 auto constant = rewriter.create<LLVM::ConstantOp>( 72 loc, typeConverter.convertType(idxType), 73 rewriter.getIntegerAttr(idxType, pos)); 74 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 75 constant); 76 } 77 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 78 rewriter.getI64ArrayAttr(pos)); 79 } 80 81 // Helper that returns data layout alignment of a memref. 82 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 83 MemRefType memrefType, unsigned &align) { 84 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 85 if (!elementTy) 86 return failure(); 87 88 // TODO: this should use the MLIR data layout when it becomes available and 89 // stop depending on translation. 90 llvm::LLVMContext llvmContext; 91 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 92 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 93 return success(); 94 } 95 96 // Add an index vector component to a base pointer. 
This almost always succeeds 97 // unless the last stride is non-unit or the memory space is not zero. 98 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 99 Location loc, Value memref, Value base, 100 Value index, MemRefType memRefType, 101 VectorType vType, Value &ptrs) { 102 int64_t offset; 103 SmallVector<int64_t, 4> strides; 104 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 105 if (failed(successStrides) || strides.back() != 1 || 106 memRefType.getMemorySpaceAsInt() != 0) 107 return failure(); 108 auto pType = MemRefDescriptor(memref).getElementPtrType(); 109 auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); 110 ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index); 111 return success(); 112 } 113 114 // Casts a strided element pointer to a vector pointer. The vector pointer 115 // will be in the same address space as the incoming memref type. 116 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 117 Value ptr, MemRefType memRefType, Type vt) { 118 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); 119 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 120 } 121 122 namespace { 123 124 /// Trivial Vector to LLVM conversions 125 using VectorScaleOpConversion = 126 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 127 128 /// Conversion pattern for a vector.bitcast. 129 class VectorBitCastOpConversion 130 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 131 public: 132 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 133 134 LogicalResult 135 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 136 ConversionPatternRewriter &rewriter) const override { 137 // Only 0-D and 1-D vectors can be lowered to LLVM. 
138 VectorType resultTy = bitCastOp.getResultVectorType(); 139 if (resultTy.getRank() > 1) 140 return failure(); 141 Type newResultTy = typeConverter->convertType(resultTy); 142 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 143 adaptor.getOperands()[0]); 144 return success(); 145 } 146 }; 147 148 /// Conversion pattern for a vector.matrix_multiply. 149 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 150 class VectorMatmulOpConversion 151 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 152 public: 153 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 154 155 LogicalResult 156 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 157 ConversionPatternRewriter &rewriter) const override { 158 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 159 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 160 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 161 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 162 return success(); 163 } 164 }; 165 166 /// Conversion pattern for a vector.flat_transpose. 167 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 168 class VectorFlatTransposeOpConversion 169 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 170 public: 171 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 172 173 LogicalResult 174 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 175 ConversionPatternRewriter &rewriter) const override { 176 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 177 transOp, typeConverter->convertType(transOp.getRes().getType()), 178 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 179 return success(); 180 } 181 }; 182 183 /// Overloaded utility that replaces a vector.load, vector.store, 184 /// vector.maskedload and vector.maskedstore with their respective LLVM 185 /// couterparts. 
// vector.load -> llvm.load (plain vector load at the given alignment).
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

// vector.maskedload -> llvm.intr.masked.load (mask + pass-through operand).
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

// vector.store -> llvm.store.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

// vector.maskedstore -> llvm.intr.masked.store.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore. Shared logic: compute alignment, resolve the strided
/// element pointer, bitcast it to a vector pointer, then dispatch to the
/// matching `replaceLoadOrStoreOp` overload above.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: strided element pointer, then cast to a pointer to the
    // lowered vector type (same address space as the memref).
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: base strided pointer offset by the index vector,
    // yielding one pointer per gathered lane.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: one pointer per scattered lane (base + index vector).
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    // llvm.intr.masked.expandload reads consecutive elements and expands them
    // into the enabled lanes of the result; disabled lanes take pass-through.
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    // llvm.intr.masked.compressstore writes the enabled lanes contiguously.
    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
/// non-null.
template <class VectorOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {
  // Reduce the vector, then (optionally) combine with the accumulator using
  // the scalar op.
  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`.
/// `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
/// the accumulator value if non-null.
template <class VectorOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    // Combine with the accumulator via compare + select (no scalar min/max op
    // is used here).
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}

/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
/// with vector types.
/// Semantics implemented below: compare/select picks the min (olt) or max
/// (ogt) of the two operands, and any unordered comparison (either operand
/// NaN) yields a quiet NaN of the element type.
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
  Type i1Type = builder.getI1Type();
  // NOTE(review): for vector operands the i1 mask type is built from the
  // shape only (VectorType::get(shape, i1)) — this looks like it drops any
  // scalable-dims information of `lhs`; confirm scalable vectors are not
  // expected here.
  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  // If either operand is NaN (unordered), produce a canonical quiet NaN.
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  // `reassociateFPRed` controls whether the generated llvm.intr.vector.reduce
  // fadd/fmul intrinsics are allowed to reassociate (faster, less precise).
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds are floating-point only; not valid for int elements.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == vector::CombiningKind::ADD) {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MUL) {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way. The accumulator is folded in with
      // createMinMaxF, which supplies the NaN-propagating semantics.
      Value result =
          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
      if (acc)
        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
      rewriter.replaceOp(reductionOp, result);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      Value result =
          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
      if (acc)
        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
      rewriter.replaceOp(reductionOp, result);
    } else
      return failure();

    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.shuffle. Rank-1 shuffles with identical
/// operand types map directly to llvm.shufflevector; everything else is
/// expanded into per-element extract/insert sequences.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries < v1Dim select from v1; entries >= v1Dim select (entry -
    // v1Dim) from v2.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement. Lowers to
/// llvm.extractelement; 0-D vectors are treated as 1-element vectors with a
/// constant zero index.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D vector: extract lane 0.
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

/// Conversion pattern for a vector.extract (static positions). n-D vectors
/// lower to LLVM arrays-of-vectors, so extraction is a combination of
/// llvm.extractvalue (array levels) and llvm.extractelement (innermost
/// vector level).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      // Peel off all-but-last position indices with one extractvalue.
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///    vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Higher-rank FMAs are unrolled first by VectorFMAOpNDRewritePattern.
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement. Lowers to
/// llvm.insertelement; 0-D vectors use a constant zero index.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D vector: insert at lane 0.
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Conversion pattern for a vector.insert (static positions). Mirrors
/// VectorExtractOpConversion: array levels use llvm.insertvalue /
/// llvm.extractvalue, the innermost vector level uses llvm.insertelement.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      // Pull out the innermost 1-D vector so the element can be inserted.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.getDest(), inserted,
          nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    // 1-D FMAs are handled directly by VectorFMAOp1DConversion.
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    // Accumulate the per-slice results into a zero-splatted vector of the
    // original shape.
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  // The innermost stride must be 1 for the buffer to be contiguous.
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    // Contiguity requires stride[i] == stride[i+1] * size[i+1].
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

/// Conversion of vector.type_cast: rewrites the memref descriptor of the
/// source into a descriptor for the target vector-element memref, reusing the
/// same underlying buffer (bitcast pointers, static sizes/strides, offset 0).
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  /// `enableIndexOpt` selects 32-bit vector indices (instead of 64-bit) for
  /// the comparison lanes.
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Only 1-D scalable masks are handled here; everything else is lowered
    // by the Vector transforms mentioned above.
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // mask[i] = (i < bound), computed as step_vector < splat(bound).
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Extension applied to each scalar before it is handed to the runtime
  // printing function.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  // Recursively unrolls `value` (of vector type `type` with rank `rank`)
  // into scalar print calls, interleaving open/close brackets and commas.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      // Innermost level: extract and print each scalar element.
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      // 0-D vectors still hold one printable element.
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Higher ranks: peel the leading dimension and recurse on each slice.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
1191 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1192 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1193 1194 LogicalResult 1195 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1196 ConversionPatternRewriter &rewriter) const override { 1197 VectorType resultType = splatOp.getType().cast<VectorType>(); 1198 if (resultType.getRank() > 1) 1199 return failure(); 1200 1201 // First insert it into an undef vector so we can shuffle it. 1202 auto vectorType = typeConverter->convertType(splatOp.getType()); 1203 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1204 auto zero = rewriter.create<LLVM::ConstantOp>( 1205 splatOp.getLoc(), 1206 typeConverter->convertType(rewriter.getIntegerType(32)), 1207 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1208 1209 // For 0-d vector, we simply do `insertelement`. 1210 if (resultType.getRank() == 0) { 1211 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1212 splatOp, vectorType, undef, adaptor.getInput(), zero); 1213 return success(); 1214 } 1215 1216 // For 1-d vector, we additionally do a `vectorshuffle`. 1217 auto v = rewriter.create<LLVM::InsertElementOp>( 1218 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1219 1220 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); 1221 SmallVector<int32_t, 4> zeroValues(width, 0); 1222 1223 // Shuffle the value across the desired number of elements. 1224 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); 1225 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1226 zeroAttrs); 1227 return success(); 1228 } 1229 }; 1230 1231 /// The Splat operation is lowered to an insertelement + a shufflevector 1232 /// operation. Splat to only 2+-d vector result types are lowered by the 1233 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // 0-D and 1-D splats are handled by VectorSplatOpLowering.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate over the linear index, convert to coords space and insert
    // splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

/// Populate the given list with patterns that convert vector contraction /
/// transpose ops to LLVM matrix intrinsics.
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}