//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
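// For rank 1 this lowers to an llvm.extractelement with a constant index;
// for higher ranks it lowers to an llvm.extractvalue on the array-of-vectors
// representation.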
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  LLVM::LLVMDialect *dialect = typeConverter.getDialect();
  align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     dialect->getLLVMModule().getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
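// The pointer is obtained with an index-free llvm.getelementptr on the
// aligned base computed by getBase().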
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns vector of pointers given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, load, align)))
      return failure();

    auto vtype = typeConverter.convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, store, align)))
      return failure();

    auto vtype = typeConverter.convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0)
                       ? ValueRange({})
                       : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
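/// The op is lowered to the llvm.intr.masked.scatter intrinsic operating on
/// a vector of pointers formed from the memref base and the index vector.
/// Roughly (a sketch, with types abbreviated):
/// ```
///   %ptrs = llvm.getelementptr %base[%indices]
///   llvm.intr.masked.scatter %value, %ptrs, %mask {alignment = ...}
/// ```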
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          expand.getMemRefType(), ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
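/// For example (a sketch; the exact intrinsic depends on the reduction kind
/// and element type):
/// ```
///   %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// lowers to llvm.intr.experimental.vector.reduce.v2.fadd seeded with a zero
/// accumulator (or with the optional accumulator operand, when present).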
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the values one by one.
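    // Each mask entry selects a position in either v1 (entries < v1Dim) or
    // v2 (entries >= v1Dim, rebased by v1Dim); the selected element or
    // subvector is extracted and inserted at the next result position.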
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
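    // For a position like [i, j, k] on a 3-D vector, the leading [i, j]
    // coordinates are resolved with a single llvm.extractvalue and the
    // trailing k with an llvm.extractelement on the remaining 1-D vector.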
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va:
///     (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///     -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0 : vector<2x4xf32>
///   %va = vector.extractvalue %a[0] : vector<2x4xf32>
///   %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///   %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extractvalue %a[1] : vector<2x4xf32>
///   %vb2 = vector.extractvalue %b[1] : vector<2x4xf32>
///   %vc2 = vector.extractvalue %c[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
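//
// For example (a sketch, with types abbreviated):
// ```
//   %t = vector.insert_strided_slice %s, %d
//       {offsets = [2, 0, 0], strides = [1, 1]}
//       : vector<2x4xf32> into vector<16x2x4xf32>
// ```
// first extracts %d[2] : vector<2x4xf32>, re-creates the op at offsets [0, 0]
// on that same-rank subvector, and inserts the result back at %d[2].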
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source along the outermost
// dimension:
//   1. the corresponding subvector (or element) is extracted from the source
//   2. for the vector case, a rank-reduced InsertStridedSlice op inserts it
//      into the matching destination subvector
//   3. the result is inserted back into the destination at the right offset.
// The rank-reduced InsertStridedSlice ops created in step 2. are picked up
// again by this same pattern until the element level is reached.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides =
        getStridesAndOffset(sourceMemRefType, strides, offset);
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = static_cast<int>(strides.size()) - 2; index < e;
           ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    // Only contiguous source memrefs supported atm.
    if (failed(successStrides) || !isContiguous)
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
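    //    For example, a transfer of vector<8xf32> materializes the constant
    //    i64 vector [0, 1, 2, 3, 4, 5, 6, 7].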
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim be the memref dimension, compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};

class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and dimensionality. Due to full
  // unrolling, this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() ==
          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
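  // The print_* functions below are expected to come from the small runtime
  // support library mentioned above, e.g. the utilities linked into
  // mlir-cpu-runner (an assumption; any library exposing these C symbols
  // works).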
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. extractelement + insertelement for the 1-D case
///   2. extract + optional strided_slice + insert for the n-D case.
class VectorStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value extracted = extractOne(rewriter, loc, op.vector(), off);
      if (op.offsets().getValue().size() > 1) {
        extracted = rewriter.create<ExtractStridedSliceOp>(
            loc, extracted, getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.sizes(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, {res});
    return success();
  }
  /// This pattern creates recursive ExtractStridedSliceOp, but the recursion
  /// is bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
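/// A typical client (see LowerVectorToLLVMPass::runOnOperation below) mixes
/// these with the standard-to-LLVM patterns before running a partial
/// conversion, e.g.:
///
///   LLVMTypeConverter converter(&getContext());
///   OwningRewritePatternList patterns;
///   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   populateStdToLLVMConversionPatterns(converter, patterns);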
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         reassociateFPReductions);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  if (failed(applyPartialConversion(getOperation(), target, patterns)))
    signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}