1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Arith/Utils/Utils.h" 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Target/LLVMIR/TypeToLLVM.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 using namespace mlir::vector; 25 26 // Helper to reduce vector type by one rank at front. 27 static VectorType reducedVectorTypeFront(VectorType tp) { 28 assert((tp.getRank() > 1) && "unlowerable vector type"); 29 unsigned numScalableDims = tp.getNumScalableDims(); 30 if (tp.getShape().size() == numScalableDims) 31 --numScalableDims; 32 return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), 33 numScalableDims); 34 } 35 36 // Helper to reduce vector type by *all* but one rank at back. 37 static VectorType reducedVectorTypeBack(VectorType tp) { 38 assert((tp.getRank() > 1) && "unlowerable vector type"); 39 unsigned numScalableDims = tp.getNumScalableDims(); 40 if (numScalableDims > 0) 41 --numScalableDims; 42 return VectorType::get(tp.getShape().take_back(), tp.getElementType(), 43 numScalableDims); 44 } 45 46 // Helper that picks the proper sequence for inserting. 
47 static Value insertOne(ConversionPatternRewriter &rewriter, 48 LLVMTypeConverter &typeConverter, Location loc, 49 Value val1, Value val2, Type llvmType, int64_t rank, 50 int64_t pos) { 51 assert(rank > 0 && "0-D vector corner case should have been handled already"); 52 if (rank == 1) { 53 auto idxType = rewriter.getIndexType(); 54 auto constant = rewriter.create<LLVM::ConstantOp>( 55 loc, typeConverter.convertType(idxType), 56 rewriter.getIntegerAttr(idxType, pos)); 57 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 58 constant); 59 } 60 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); 61 } 62 63 // Helper that picks the proper sequence for extracting. 64 static Value extractOne(ConversionPatternRewriter &rewriter, 65 LLVMTypeConverter &typeConverter, Location loc, 66 Value val, Type llvmType, int64_t rank, int64_t pos) { 67 if (rank <= 1) { 68 auto idxType = rewriter.getIndexType(); 69 auto constant = rewriter.create<LLVM::ConstantOp>( 70 loc, typeConverter.convertType(idxType), 71 rewriter.getIntegerAttr(idxType, pos)); 72 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 73 constant); 74 } 75 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); 76 } 77 78 // Helper that returns data layout alignment of a memref. 79 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 80 MemRefType memrefType, unsigned &align) { 81 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 82 if (!elementTy) 83 return failure(); 84 85 // TODO: this should use the MLIR data layout when it becomes available and 86 // stop depending on translation. 87 llvm::LLVMContext llvmContext; 88 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 89 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 90 return success(); 91 } 92 93 // Check if the last stride is non-unit or the memory space is not zero. 
94 static LogicalResult isMemRefTypeSupported(MemRefType memRefType) { 95 int64_t offset; 96 SmallVector<int64_t, 4> strides; 97 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 98 if (failed(successStrides) || strides.back() != 1 || 99 memRefType.getMemorySpaceAsInt() != 0) 100 return failure(); 101 return success(); 102 } 103 104 // Add an index vector component to a base pointer. 105 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, 106 MemRefType memRefType, Value llvmMemref, Value base, 107 Value index, uint64_t vLen) { 108 assert(succeeded(isMemRefTypeSupported(memRefType)) && 109 "unsupported memref type"); 110 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType(); 111 auto ptrsType = LLVM::getFixedVectorType(pType, vLen); 112 return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index); 113 } 114 115 // Casts a strided element pointer to a vector pointer. The vector pointer 116 // will be in the same address space as the incoming memref type. 117 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 118 Value ptr, MemRefType memRefType, Type vt) { 119 auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); 120 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 121 } 122 123 namespace { 124 125 /// Trivial Vector to LLVM conversions 126 using VectorScaleOpConversion = 127 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>; 128 129 /// Conversion pattern for a vector.bitcast. 130 class VectorBitCastOpConversion 131 : public ConvertOpToLLVMPattern<vector::BitCastOp> { 132 public: 133 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 134 135 LogicalResult 136 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, 137 ConversionPatternRewriter &rewriter) const override { 138 // Only 0-D and 1-D vectors can be lowered to LLVM. 
139 VectorType resultTy = bitCastOp.getResultVectorType(); 140 if (resultTy.getRank() > 1) 141 return failure(); 142 Type newResultTy = typeConverter->convertType(resultTy); 143 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 144 adaptor.getOperands()[0]); 145 return success(); 146 } 147 }; 148 149 /// Conversion pattern for a vector.matrix_multiply. 150 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 151 class VectorMatmulOpConversion 152 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 153 public: 154 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 155 156 LogicalResult 157 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, 158 ConversionPatternRewriter &rewriter) const override { 159 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 160 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), 161 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), 162 matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); 163 return success(); 164 } 165 }; 166 167 /// Conversion pattern for a vector.flat_transpose. 168 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 169 class VectorFlatTransposeOpConversion 170 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 171 public: 172 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 173 174 LogicalResult 175 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, 176 ConversionPatternRewriter &rewriter) const override { 177 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 178 transOp, typeConverter->convertType(transOp.getRes().getType()), 179 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); 180 return success(); 181 } 182 }; 183 184 /// Overloaded utility that replaces a vector.load, vector.store, 185 /// vector.maskedload and vector.maskedstore with their respective LLVM 186 /// couterparts. 
187 static void replaceLoadOrStoreOp(vector::LoadOp loadOp, 188 vector::LoadOpAdaptor adaptor, 189 VectorType vectorTy, Value ptr, unsigned align, 190 ConversionPatternRewriter &rewriter) { 191 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align); 192 } 193 194 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, 195 vector::MaskedLoadOpAdaptor adaptor, 196 VectorType vectorTy, Value ptr, unsigned align, 197 ConversionPatternRewriter &rewriter) { 198 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 199 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); 200 } 201 202 static void replaceLoadOrStoreOp(vector::StoreOp storeOp, 203 vector::StoreOpAdaptor adaptor, 204 VectorType vectorTy, Value ptr, unsigned align, 205 ConversionPatternRewriter &rewriter) { 206 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), 207 ptr, align); 208 } 209 210 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, 211 vector::MaskedStoreOpAdaptor adaptor, 212 VectorType vectorTy, Value ptr, unsigned align, 213 ConversionPatternRewriter &rewriter) { 214 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 215 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); 216 } 217 218 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and 219 /// vector.maskedstore. 220 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor> 221 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { 222 public: 223 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; 224 225 LogicalResult 226 matchAndRewrite(LoadOrStoreOp loadOrStoreOp, 227 typename LoadOrStoreOp::Adaptor adaptor, 228 ConversionPatternRewriter &rewriter) const override { 229 // Only 1-D vectors can be lowered to LLVM. 
230 VectorType vectorTy = loadOrStoreOp.getVectorType(); 231 if (vectorTy.getRank() > 1) 232 return failure(); 233 234 auto loc = loadOrStoreOp->getLoc(); 235 MemRefType memRefTy = loadOrStoreOp.getMemRefType(); 236 237 // Resolve alignment. 238 unsigned align; 239 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) 240 return failure(); 241 242 // Resolve address. 243 auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) 244 .template cast<VectorType>(); 245 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), 246 adaptor.getIndices(), rewriter); 247 Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); 248 249 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); 250 return success(); 251 } 252 }; 253 254 /// Conversion pattern for a vector.gather. 255 class VectorGatherOpConversion 256 : public ConvertOpToLLVMPattern<vector::GatherOp> { 257 public: 258 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 259 260 LogicalResult 261 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, 262 ConversionPatternRewriter &rewriter) const override { 263 MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>(); 264 assert(memRefType && "The base should be bufferized"); 265 266 if (failed(isMemRefTypeSupported(memRefType))) 267 return failure(); 268 269 auto loc = gather->getLoc(); 270 271 // Resolve alignment. 272 unsigned align; 273 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 274 return failure(); 275 276 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 277 adaptor.getIndices(), rewriter); 278 Value base = adaptor.getBase(); 279 280 auto llvmNDVectorTy = adaptor.getIndexVec().getType(); 281 // Handle the simple case of 1-D vector. 282 if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) { 283 auto vType = gather.getVectorType(); 284 // Resolve address. 
285 Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr, 286 adaptor.getIndexVec(), 287 /*vLen=*/vType.getDimSize(0)); 288 // Replace with the gather intrinsic. 289 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 290 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), 291 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); 292 return success(); 293 } 294 295 auto callback = [align, memRefType, base, ptr, loc, &rewriter]( 296 Type llvm1DVectorTy, ValueRange vectorOperands) { 297 // Resolve address. 298 Value ptrs = getIndexedPtrs( 299 rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0], 300 LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()); 301 // Create the gather intrinsic. 302 return rewriter.create<LLVM::masked_gather>( 303 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1], 304 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align)); 305 }; 306 SmallVector<Value> vectorOperands = { 307 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()}; 308 return LLVM::detail::handleMultidimensionalVectors( 309 gather, vectorOperands, *getTypeConverter(), callback, rewriter); 310 } 311 }; 312 313 /// Conversion pattern for a vector.scatter. 314 class VectorScatterOpConversion 315 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 316 public: 317 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 318 319 LogicalResult 320 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, 321 ConversionPatternRewriter &rewriter) const override { 322 auto loc = scatter->getLoc(); 323 MemRefType memRefType = scatter.getMemRefType(); 324 325 if (failed(isMemRefTypeSupported(memRefType))) 326 return failure(); 327 328 // Resolve alignment. 329 unsigned align; 330 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 331 return failure(); 332 333 // Resolve address. 
334 VectorType vType = scatter.getVectorType(); 335 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 336 adaptor.getIndices(), rewriter); 337 Value ptrs = 338 getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr, 339 adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0)); 340 341 // Replace with the scatter intrinsic. 342 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 343 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), 344 rewriter.getI32IntegerAttr(align)); 345 return success(); 346 } 347 }; 348 349 /// Conversion pattern for a vector.expandload. 350 class VectorExpandLoadOpConversion 351 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 352 public: 353 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 354 355 LogicalResult 356 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, 357 ConversionPatternRewriter &rewriter) const override { 358 auto loc = expand->getLoc(); 359 MemRefType memRefType = expand.getMemRefType(); 360 361 // Resolve address. 362 auto vtype = typeConverter->convertType(expand.getVectorType()); 363 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 364 adaptor.getIndices(), rewriter); 365 366 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 367 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); 368 return success(); 369 } 370 }; 371 372 /// Conversion pattern for a vector.compressstore. 373 class VectorCompressStoreOpConversion 374 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 375 public: 376 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 377 378 LogicalResult 379 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, 380 ConversionPatternRewriter &rewriter) const override { 381 auto loc = compress->getLoc(); 382 MemRefType memRefType = compress.getMemRefType(); 383 384 // Resolve address. 
385 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), 386 adaptor.getIndices(), rewriter); 387 388 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 389 compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); 390 return success(); 391 } 392 }; 393 394 /// Helper method to lower a `vector.reduction` op that performs an arithmetic 395 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use 396 /// and `ScalarOp` is the scalar operation used to add the accumulation value if 397 /// non-null. 398 template <class VectorOp, class ScalarOp> 399 static Value createIntegerReductionArithmeticOpLowering( 400 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 401 Value vectorOperand, Value accumulator) { 402 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 403 if (accumulator) 404 result = rewriter.create<ScalarOp>(loc, accumulator, result); 405 return result; 406 } 407 408 /// Helper method to lower a `vector.reduction` operation that performs 409 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector 410 /// intrinsic to use and `predicate` is the predicate to use to compare+combine 411 /// the accumulator value if non-null. 412 template <class VectorOp> 413 static Value createIntegerReductionComparisonOpLowering( 414 ConversionPatternRewriter &rewriter, Location loc, Type llvmType, 415 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { 416 Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand); 417 if (accumulator) { 418 Value cmp = 419 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); 420 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); 421 } 422 return result; 423 } 424 425 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum 426 /// with vector types. 
427 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, 428 Value rhs, bool isMin) { 429 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 430 Type i1Type = builder.getI1Type(); 431 if (auto vecType = lhs.getType().dyn_cast<VectorType>()) 432 i1Type = VectorType::get(vecType.getShape(), i1Type); 433 Value cmp = builder.create<LLVM::FCmpOp>( 434 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, 435 lhs, rhs); 436 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); 437 Value isNan = builder.create<LLVM::FCmpOp>( 438 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); 439 Value nan = builder.create<LLVM::ConstantOp>( 440 loc, lhs.getType(), 441 builder.getFloatAttr(floatType, 442 APFloat::getQNaN(floatType.getFloatSemantics()))); 443 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); 444 } 445 446 /// Conversion pattern for all vector reductions. 447 class VectorReductionOpConversion 448 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 449 public: 450 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 451 bool reassociateFPRed) 452 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 453 reassociateFPReductions(reassociateFPRed) {} 454 455 LogicalResult 456 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, 457 ConversionPatternRewriter &rewriter) const override { 458 auto kind = reductionOp.getKind(); 459 Type eltType = reductionOp.getDest().getType(); 460 Type llvmType = typeConverter->convertType(eltType); 461 Value operand = adaptor.getVector(); 462 Value acc = adaptor.getAcc(); 463 Location loc = reductionOp.getLoc(); 464 if (eltType.isIntOrIndex()) { 465 // Integer reductions: add/mul/min/max/and/or/xor. 
466 Value result; 467 switch (kind) { 468 case vector::CombiningKind::ADD: 469 result = 470 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add, 471 LLVM::AddOp>( 472 rewriter, loc, llvmType, operand, acc); 473 break; 474 case vector::CombiningKind::MUL: 475 result = 476 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul, 477 LLVM::MulOp>( 478 rewriter, loc, llvmType, operand, acc); 479 break; 480 case vector::CombiningKind::MINUI: 481 result = createIntegerReductionComparisonOpLowering< 482 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, 483 LLVM::ICmpPredicate::ule); 484 break; 485 case vector::CombiningKind::MINSI: 486 result = createIntegerReductionComparisonOpLowering< 487 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, 488 LLVM::ICmpPredicate::sle); 489 break; 490 case vector::CombiningKind::MAXUI: 491 result = createIntegerReductionComparisonOpLowering< 492 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, 493 LLVM::ICmpPredicate::uge); 494 break; 495 case vector::CombiningKind::MAXSI: 496 result = createIntegerReductionComparisonOpLowering< 497 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, 498 LLVM::ICmpPredicate::sge); 499 break; 500 case vector::CombiningKind::AND: 501 result = 502 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and, 503 LLVM::AndOp>( 504 rewriter, loc, llvmType, operand, acc); 505 break; 506 case vector::CombiningKind::OR: 507 result = 508 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or, 509 LLVM::OrOp>( 510 rewriter, loc, llvmType, operand, acc); 511 break; 512 case vector::CombiningKind::XOR: 513 result = 514 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor, 515 LLVM::XOrOp>( 516 rewriter, loc, llvmType, operand, acc); 517 break; 518 default: 519 return failure(); 520 } 521 rewriter.replaceOp(reductionOp, result); 522 523 return success(); 524 } 525 526 if (!eltType.isa<FloatType>()) 527 
return failure(); 528 529 // Floating-point reductions: add/mul/min/max 530 if (kind == vector::CombiningKind::ADD) { 531 // Optional accumulator (or zero). 532 Value acc = adaptor.getOperands().size() > 1 533 ? adaptor.getOperands()[1] 534 : rewriter.create<LLVM::ConstantOp>( 535 reductionOp->getLoc(), llvmType, 536 rewriter.getZeroAttr(eltType)); 537 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 538 reductionOp, llvmType, acc, operand, 539 rewriter.getBoolAttr(reassociateFPReductions)); 540 } else if (kind == vector::CombiningKind::MUL) { 541 // Optional accumulator (or one). 542 Value acc = adaptor.getOperands().size() > 1 543 ? adaptor.getOperands()[1] 544 : rewriter.create<LLVM::ConstantOp>( 545 reductionOp->getLoc(), llvmType, 546 rewriter.getFloatAttr(eltType, 1.0)); 547 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 548 reductionOp, llvmType, acc, operand, 549 rewriter.getBoolAttr(reassociateFPReductions)); 550 } else if (kind == vector::CombiningKind::MINF) { 551 // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle 552 // NaNs/-0.0/+0.0 in the same way. 553 Value result = 554 rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand); 555 if (acc) 556 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true); 557 rewriter.replaceOp(reductionOp, result); 558 } else if (kind == vector::CombiningKind::MAXF) { 559 // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle 560 // NaNs/-0.0/+0.0 in the same way. 
561 Value result = 562 rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand); 563 if (acc) 564 result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false); 565 rewriter.replaceOp(reductionOp, result); 566 } else 567 return failure(); 568 569 return success(); 570 } 571 572 private: 573 const bool reassociateFPReductions; 574 }; 575 576 class VectorShuffleOpConversion 577 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 578 public: 579 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 580 581 LogicalResult 582 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 583 ConversionPatternRewriter &rewriter) const override { 584 auto loc = shuffleOp->getLoc(); 585 auto v1Type = shuffleOp.getV1VectorType(); 586 auto v2Type = shuffleOp.getV2VectorType(); 587 auto vectorType = shuffleOp.getVectorType(); 588 Type llvmType = typeConverter->convertType(vectorType); 589 auto maskArrayAttr = shuffleOp.getMask(); 590 591 // Bail if result type cannot be lowered. 592 if (!llvmType) 593 return failure(); 594 595 // Get rank and dimension sizes. 596 int64_t rank = vectorType.getRank(); 597 #ifndef NDEBUG 598 bool wellFormed0DCase = 599 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1; 600 bool wellFormedNDCase = 601 v1Type.getRank() == rank && v2Type.getRank() == rank; 602 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed"); 603 #endif 604 605 // For rank 0 and 1, where both operands have *exactly* the same vector 606 // type, there is direct shuffle support in LLVM. Use it! 607 if (rank <= 1 && v1Type == v2Type) { 608 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 609 loc, adaptor.getV1(), adaptor.getV2(), 610 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr)); 611 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 612 return success(); 613 } 614 615 // For all other cases, insert the individual values individually. 
616 int64_t v1Dim = v1Type.getDimSize(0); 617 Type eltType; 618 if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>()) 619 eltType = arrayType.getElementType(); 620 else 621 eltType = llvmType.cast<VectorType>().getElementType(); 622 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 623 int64_t insPos = 0; 624 for (const auto &en : llvm::enumerate(maskArrayAttr)) { 625 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 626 Value value = adaptor.getV1(); 627 if (extPos >= v1Dim) { 628 extPos -= v1Dim; 629 value = adaptor.getV2(); 630 } 631 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 632 eltType, rank, extPos); 633 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 634 llvmType, rank, insPos++); 635 } 636 rewriter.replaceOp(shuffleOp, insert); 637 return success(); 638 } 639 }; 640 641 class VectorExtractElementOpConversion 642 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 643 public: 644 using ConvertOpToLLVMPattern< 645 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 646 647 LogicalResult 648 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, 649 ConversionPatternRewriter &rewriter) const override { 650 auto vectorType = extractEltOp.getVectorType(); 651 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 652 653 // Bail if result type cannot be lowered. 
654 if (!llvmType) 655 return failure(); 656 657 if (vectorType.getRank() == 0) { 658 Location loc = extractEltOp.getLoc(); 659 auto idxType = rewriter.getIndexType(); 660 auto zero = rewriter.create<LLVM::ConstantOp>( 661 loc, typeConverter->convertType(idxType), 662 rewriter.getIntegerAttr(idxType, 0)); 663 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 664 extractEltOp, llvmType, adaptor.getVector(), zero); 665 return success(); 666 } 667 668 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 669 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); 670 return success(); 671 } 672 }; 673 674 class VectorExtractOpConversion 675 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 676 public: 677 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 678 679 LogicalResult 680 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 681 ConversionPatternRewriter &rewriter) const override { 682 auto loc = extractOp->getLoc(); 683 auto resultType = extractOp.getResult().getType(); 684 auto llvmResultType = typeConverter->convertType(resultType); 685 auto positionArrayAttr = extractOp.getPosition(); 686 687 // Bail if result type cannot be lowered. 688 if (!llvmResultType) 689 return failure(); 690 691 // Extract entire vector. Should be handled by folder, but just to be safe. 692 if (positionArrayAttr.empty()) { 693 rewriter.replaceOp(extractOp, adaptor.getVector()); 694 return success(); 695 } 696 697 // One-shot extraction of vector from array (only requires extractvalue). 698 if (resultType.isa<VectorType>()) { 699 SmallVector<int64_t> indices; 700 for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>()) 701 indices.push_back(idx.getInt()); 702 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 703 loc, adaptor.getVector(), indices); 704 rewriter.replaceOp(extractOp, extracted); 705 return success(); 706 } 707 708 // Potential extraction of 1-D vector from array. 
709 Value extracted = adaptor.getVector(); 710 auto positionAttrs = positionArrayAttr.getValue(); 711 if (positionAttrs.size() > 1) { 712 SmallVector<int64_t> nMinusOnePosition; 713 for (auto idx : positionAttrs.drop_back()) 714 nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt()); 715 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted, 716 nMinusOnePosition); 717 } 718 719 // Remaining extraction of element from 1-D LLVM vector 720 auto position = positionAttrs.back().cast<IntegerAttr>(); 721 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 722 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 723 extracted = 724 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 725 rewriter.replaceOp(extractOp, extracted); 726 727 return success(); 728 } 729 }; 730 731 /// Conversion pattern that turns a vector.fma on a 1-D vector 732 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 733 /// This does not match vectors of n >= 2 rank. 
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
/// ```
/// Only vectors of rank > 1 are rejected, so 0-D vectors are lowered too.
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // n-D FMAs are left to a rank-reducing rewrite pattern.
    if (vType.getRank() > 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

/// Lowers vector.insertelement to llvm.insertelement. A 0-D destination has no
/// position operand, so a constant index 0 is used instead.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

/// Lowers vector.insert with static positions.
/// - A vector source is placed into the converted aggregate with a single
///   llvm.insertvalue.
/// - A scalar source goes through three steps. The enclosing 1-D vector is
///   extracted (if rank > 1), the scalar is inserted with llvm.insertelement
///   at an i64 constant, and the 1-D vector is written back with
///   llvm.insertvalue.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // NOTE(review): for scalable destinations, the result type of the
    // insertelement below is only correct if reducedVectorTypeBack keeps the
    // scalability of the trailing dimension — verify.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
/// (operands forwarded as source, destination, position).
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
    return success();
  }
};

/// Lower vector.scalable.extract ops to LLVM vector.extract
/// (result type is the converted result vector type).
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
  using ConvertOpToLLVMPattern<
      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
        extOp, typeConverter->convertType(extOp.getResultVectorType()),
        adaptor.getSource(), adaptor.getPos());
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // bounded as the rank is strictly decreasing.
918 setHasBoundedRewriteRecursion(); 919 } 920 921 LogicalResult matchAndRewrite(FMAOp op, 922 PatternRewriter &rewriter) const override { 923 auto vType = op.getVectorType(); 924 if (vType.getRank() < 2) 925 return failure(); 926 927 auto loc = op.getLoc(); 928 auto elemType = vType.getElementType(); 929 Value zero = rewriter.create<arith::ConstantOp>( 930 loc, elemType, rewriter.getZeroAttr(elemType)); 931 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); 932 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 933 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); 934 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); 935 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); 936 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 937 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 938 } 939 rewriter.replaceOp(op, desc); 940 return success(); 941 } 942 }; 943 944 /// Returns the strides if the memory underlying `memRefType` has a contiguous 945 /// static layout. 946 static llvm::Optional<SmallVector<int64_t, 4>> 947 computeContiguousStrides(MemRefType memRefType) { 948 int64_t offset; 949 SmallVector<int64_t, 4> strides; 950 if (failed(getStridesAndOffset(memRefType, strides, offset))) 951 return std::nullopt; 952 if (!strides.empty() && strides.back() != 1) 953 return std::nullopt; 954 // If no layout or identity layout, this is contiguous by definition. 955 if (memRefType.getLayout().isIdentity()) 956 return strides; 957 958 // Otherwise, we must determine contiguity form shapes. This can only ever 959 // work in static cases because MemRefType is underspecified to represent 960 // contiguous dynamic shapes in other ways than with just empty/identity 961 // layout. 
962 auto sizes = memRefType.getShape(); 963 for (int index = 0, e = strides.size() - 1; index < e; ++index) { 964 if (ShapedType::isDynamic(sizes[index + 1]) || 965 ShapedType::isDynamic(strides[index]) || 966 ShapedType::isDynamic(strides[index + 1])) 967 return std::nullopt; 968 if (strides[index] != strides[index + 1] * sizes[index + 1]) 969 return std::nullopt; 970 } 971 return strides; 972 } 973 974 class VectorTypeCastOpConversion 975 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 976 public: 977 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 978 979 LogicalResult 980 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, 981 ConversionPatternRewriter &rewriter) const override { 982 auto loc = castOp->getLoc(); 983 MemRefType sourceMemRefType = 984 castOp.getOperand().getType().cast<MemRefType>(); 985 MemRefType targetMemRefType = castOp.getType(); 986 987 // Only static shape casts supported atm. 988 if (!sourceMemRefType.hasStaticShape() || 989 !targetMemRefType.hasStaticShape()) 990 return failure(); 991 992 auto llvmSourceDescriptorTy = 993 adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>(); 994 if (!llvmSourceDescriptorTy) 995 return failure(); 996 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); 997 998 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 999 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1000 if (!llvmTargetDescriptorTy) 1001 return failure(); 1002 1003 // Only contiguous source buffers supported atm. 1004 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1005 if (!sourceStrides) 1006 return failure(); 1007 auto targetStrides = computeContiguousStrides(targetMemRefType); 1008 if (!targetStrides) 1009 return failure(); 1010 // Only support static strides for now, regardless of contiguity. 
1011 if (llvm::any_of(*targetStrides, ShapedType::isDynamic)) 1012 return failure(); 1013 1014 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1015 1016 // Create descriptor. 1017 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1018 Type llvmTargetElementTy = desc.getElementPtrType(); 1019 // Set allocated ptr. 1020 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1021 allocated = 1022 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1023 desc.setAllocatedPtr(rewriter, loc, allocated); 1024 // Set aligned ptr. 1025 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1026 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1027 desc.setAlignedPtr(rewriter, loc, ptr); 1028 // Fill offset 0. 1029 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1030 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1031 desc.setOffset(rewriter, loc, zero); 1032 1033 // Fill size and stride descriptors in memref. 1034 for (const auto &indexedSize : 1035 llvm::enumerate(targetMemRefType.getShape())) { 1036 int64_t index = indexedSize.index(); 1037 auto sizeAttr = 1038 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1039 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1040 desc.setSize(rewriter, loc, index, size); 1041 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1042 (*targetStrides)[index]); 1043 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1044 desc.setStride(rewriter, loc, index, stride); 1045 } 1046 1047 rewriter.replaceOp(castOp, {desc}); 1048 return success(); 1049 } 1050 }; 1051 1052 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only). 1053 /// Non-scalable versions of this operation are handled in Vector Transforms. 
1054 class VectorCreateMaskOpRewritePattern 1055 : public OpRewritePattern<vector::CreateMaskOp> { 1056 public: 1057 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context, 1058 bool enableIndexOpt) 1059 : OpRewritePattern<vector::CreateMaskOp>(context), 1060 force32BitVectorIndices(enableIndexOpt) {} 1061 1062 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 1063 PatternRewriter &rewriter) const override { 1064 auto dstType = op.getType(); 1065 if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable()) 1066 return failure(); 1067 IntegerType idxType = 1068 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 1069 auto loc = op->getLoc(); 1070 Value indices = rewriter.create<LLVM::StepVectorOp>( 1071 loc, LLVM::getVectorType(idxType, dstType.getShape()[0], 1072 /*isScalable=*/true)); 1073 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, 1074 op.getOperand(0)); 1075 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 1076 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 1077 indices, bounds); 1078 rewriter.replaceOp(op, comp); 1079 return success(); 1080 } 1081 1082 private: 1083 const bool force32BitVectorIndices; 1084 }; 1085 1086 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1087 public: 1088 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1089 1090 // Proof-of-concept lowering implementation that relies on a small 1091 // runtime support library, which only needs to provide a few 1092 // printing methods (single value for all data types, opening/closing 1093 // bracket, comma, newline). The lowering fully unrolls a vector 1094 // in terms of these elementary printing operations. The advantage 1095 // of this approach is that the library can remain unaware of all 1096 // low-level implementation details of vectors while still supporting 1097 // output of any shaped and dimensioned vector. 
Due to full unrolling, 1098 // this approach is less suited for very large vectors though. 1099 // 1100 // TODO: rely solely on libc in future? something else? 1101 // 1102 LogicalResult 1103 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor, 1104 ConversionPatternRewriter &rewriter) const override { 1105 Type printType = printOp.getPrintType(); 1106 1107 if (typeConverter->convertType(printType) == nullptr) 1108 return failure(); 1109 1110 // Make sure element type has runtime support. 1111 PrintConversion conversion = PrintConversion::None; 1112 VectorType vectorType = printType.dyn_cast<VectorType>(); 1113 Type eltType = vectorType ? vectorType.getElementType() : printType; 1114 Operation *printer; 1115 if (eltType.isF32()) { 1116 printer = 1117 LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>()); 1118 } else if (eltType.isF64()) { 1119 printer = 1120 LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>()); 1121 } else if (eltType.isIndex()) { 1122 printer = 1123 LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>()); 1124 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1125 // Integers need a zero or sign extension on the operand 1126 // (depending on the source type) as well as a signed or 1127 // unsigned print method. Up to 64-bit is supported. 1128 unsigned width = intTy.getWidth(); 1129 if (intTy.isUnsigned()) { 1130 if (width <= 64) { 1131 if (width < 64) 1132 conversion = PrintConversion::ZeroExt64; 1133 printer = LLVM::lookupOrCreatePrintU64Fn( 1134 printOp->getParentOfType<ModuleOp>()); 1135 } else { 1136 return failure(); 1137 } 1138 } else { 1139 assert(intTy.isSignless() || intTy.isSigned()); 1140 if (width <= 64) { 1141 // Note that we *always* zero extend booleans (1-bit integers), 1142 // so that true/false is printed as 1/0 rather than -1/0. 
1143 if (width == 1) 1144 conversion = PrintConversion::ZeroExt64; 1145 else if (width < 64) 1146 conversion = PrintConversion::SignExt64; 1147 printer = LLVM::lookupOrCreatePrintI64Fn( 1148 printOp->getParentOfType<ModuleOp>()); 1149 } else { 1150 return failure(); 1151 } 1152 } 1153 } else { 1154 return failure(); 1155 } 1156 1157 // Unroll vector into elementary print calls. 1158 int64_t rank = vectorType ? vectorType.getRank() : 0; 1159 Type type = vectorType ? vectorType : eltType; 1160 emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, 1161 conversion); 1162 emitCall(rewriter, printOp->getLoc(), 1163 LLVM::lookupOrCreatePrintNewlineFn( 1164 printOp->getParentOfType<ModuleOp>())); 1165 rewriter.eraseOp(printOp); 1166 return success(); 1167 } 1168 1169 private: 1170 enum class PrintConversion { 1171 // clang-format off 1172 None, 1173 ZeroExt64, 1174 SignExt64 1175 // clang-format on 1176 }; 1177 1178 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1179 Value value, Type type, Operation *printer, int64_t rank, 1180 PrintConversion conversion) const { 1181 VectorType vectorType = type.dyn_cast<VectorType>(); 1182 Location loc = op->getLoc(); 1183 if (!vectorType) { 1184 assert(rank == 0 && "The scalar case expects rank == 0"); 1185 switch (conversion) { 1186 case PrintConversion::ZeroExt64: 1187 value = rewriter.create<arith::ExtUIOp>( 1188 loc, IntegerType::get(rewriter.getContext(), 64), value); 1189 break; 1190 case PrintConversion::SignExt64: 1191 value = rewriter.create<arith::ExtSIOp>( 1192 loc, IntegerType::get(rewriter.getContext(), 64), value); 1193 break; 1194 case PrintConversion::None: 1195 break; 1196 } 1197 emitCall(rewriter, loc, printer, value); 1198 return; 1199 } 1200 1201 emitCall(rewriter, loc, 1202 LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>())); 1203 Operation *printComma = 1204 LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>()); 1205 1206 if (rank <= 1) { 1207 auto 
reducedType = vectorType.getElementType(); 1208 auto llvmType = typeConverter->convertType(reducedType); 1209 int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0); 1210 for (int64_t d = 0; d < dim; ++d) { 1211 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1212 llvmType, /*rank=*/0, /*pos=*/d); 1213 emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0, 1214 conversion); 1215 if (d != dim - 1) 1216 emitCall(rewriter, loc, printComma); 1217 } 1218 emitCall( 1219 rewriter, loc, 1220 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1221 return; 1222 } 1223 1224 int64_t dim = vectorType.getDimSize(0); 1225 for (int64_t d = 0; d < dim; ++d) { 1226 auto reducedType = reducedVectorTypeFront(vectorType); 1227 auto llvmType = typeConverter->convertType(reducedType); 1228 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1229 llvmType, rank, d); 1230 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1231 conversion); 1232 if (d != dim - 1) 1233 emitCall(rewriter, loc, printComma); 1234 } 1235 emitCall(rewriter, loc, 1236 LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1237 } 1238 1239 // Helper to emit a call. 1240 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1241 Operation *ref, ValueRange params = ValueRange()) { 1242 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), 1243 params); 1244 } 1245 }; 1246 1247 /// The Splat operation is lowered to an insertelement + a shufflevector 1248 /// operation. Splat to only 0-d and 1-d vector result types are lowered. 
1249 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { 1250 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; 1251 1252 LogicalResult 1253 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, 1254 ConversionPatternRewriter &rewriter) const override { 1255 VectorType resultType = splatOp.getType().cast<VectorType>(); 1256 if (resultType.getRank() > 1) 1257 return failure(); 1258 1259 // First insert it into an undef vector so we can shuffle it. 1260 auto vectorType = typeConverter->convertType(splatOp.getType()); 1261 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); 1262 auto zero = rewriter.create<LLVM::ConstantOp>( 1263 splatOp.getLoc(), 1264 typeConverter->convertType(rewriter.getIntegerType(32)), 1265 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1266 1267 // For 0-d vector, we simply do `insertelement`. 1268 if (resultType.getRank() == 0) { 1269 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 1270 splatOp, vectorType, undef, adaptor.getInput(), zero); 1271 return success(); 1272 } 1273 1274 // For 1-d vector, we additionally do a `vectorshuffle`. 1275 auto v = rewriter.create<LLVM::InsertElementOp>( 1276 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); 1277 1278 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); 1279 SmallVector<int32_t> zeroValues(width, 0); 1280 1281 // Shuffle the value across the desired number of elements. 1282 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef, 1283 zeroValues); 1284 return success(); 1285 } 1286 }; 1287 1288 /// The Splat operation is lowered to an insertelement + a shufflevector 1289 /// operation. Splat to only 2+-d vector result types are lowered by the 1290 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. 
1291 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { 1292 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; 1293 1294 LogicalResult 1295 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, 1296 ConversionPatternRewriter &rewriter) const override { 1297 VectorType resultType = splatOp.getType(); 1298 if (resultType.getRank() <= 1) 1299 return failure(); 1300 1301 // First insert it into an undef vector so we can shuffle it. 1302 auto loc = splatOp.getLoc(); 1303 auto vectorTypeInfo = 1304 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); 1305 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; 1306 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; 1307 if (!llvmNDVectorTy || !llvm1DVectorTy) 1308 return failure(); 1309 1310 // Construct returned value. 1311 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy); 1312 1313 // Construct a 1-D vector with the splatted value that we insert in all the 1314 // places within the returned descriptor. 1315 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy); 1316 auto zero = rewriter.create<LLVM::ConstantOp>( 1317 loc, typeConverter->convertType(rewriter.getIntegerType(32)), 1318 rewriter.getZeroAttr(rewriter.getIntegerType(32))); 1319 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, 1320 adaptor.getInput(), zero); 1321 1322 // Shuffle the value across the desired number of elements. 1323 int64_t width = resultType.getDimSize(resultType.getRank() - 1); 1324 SmallVector<int32_t> zeroValues(width, 0); 1325 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); 1326 1327 // Iterate of linear index, convert to coords space and insert splatted 1-D 1328 // vector in each position. 
1329 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { 1330 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); 1331 }); 1332 rewriter.replaceOp(splatOp, desc); 1333 return success(); 1334 } 1335 }; 1336 1337 } // namespace 1338 1339 /// Populate the given list with patterns that convert from Vector to LLVM. 1340 void mlir::populateVectorToLLVMConversionPatterns( 1341 LLVMTypeConverter &converter, RewritePatternSet &patterns, 1342 bool reassociateFPReductions, bool force32BitVectorIndices) { 1343 MLIRContext *ctx = converter.getDialect()->getContext(); 1344 patterns.add<VectorFMAOpNDRewritePattern>(ctx); 1345 populateVectorInsertExtractStridedSliceTransforms(patterns); 1346 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); 1347 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices); 1348 patterns 1349 .add<VectorBitCastOpConversion, VectorShuffleOpConversion, 1350 VectorExtractElementOpConversion, VectorExtractOpConversion, 1351 VectorFMAOp1DConversion, VectorInsertElementOpConversion, 1352 VectorInsertOpConversion, VectorPrintOpConversion, 1353 VectorTypeCastOpConversion, VectorScaleOpConversion, 1354 VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>, 1355 VectorLoadStoreConversion<vector::MaskedLoadOp, 1356 vector::MaskedLoadOpAdaptor>, 1357 VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>, 1358 VectorLoadStoreConversion<vector::MaskedStoreOp, 1359 vector::MaskedStoreOpAdaptor>, 1360 VectorGatherOpConversion, VectorScatterOpConversion, 1361 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, 1362 VectorSplatOpLowering, VectorSplatNdOpLowering, 1363 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>( 1364 converter); 1365 // Transfer ops with rank > 1 are handled by VectorToSCF. 
1366 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 1367 } 1368 1369 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1370 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 1371 patterns.add<VectorMatmulOpConversion>(converter); 1372 patterns.add<VectorFlatTransposeOpConversion>(converter); 1373 } 1374