//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper that returns a pointer to the converted element type of
// `containerType` (any shaped type exposing getElementType). The element type
// is first translated into the LLVM dialect via `typeConverter`.
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
                                          LLVMTypeConverter &typeConverter) {
  return typeConverter.convertType(containerType.getElementType())
      .template cast<LLVM::LLVMType>()
      .getPointerTo();
}

// Helper to reduce vector type by one rank at front.
// E.g. vector<2x4xf32> -> vector<4xf32>. Requires rank > 1.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) { 53 assert((tp.getRank() > 1) && "unlowerable vector type"); 54 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 55 } 56 57 // Helper that picks the proper sequence for inserting. 58 static Value insertOne(ConversionPatternRewriter &rewriter, 59 LLVMTypeConverter &typeConverter, Location loc, 60 Value val1, Value val2, Type llvmType, int64_t rank, 61 int64_t pos) { 62 if (rank == 1) { 63 auto idxType = rewriter.getIndexType(); 64 auto constant = rewriter.create<LLVM::ConstantOp>( 65 loc, typeConverter.convertType(idxType), 66 rewriter.getIntegerAttr(idxType, pos)); 67 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 68 constant); 69 } 70 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 71 rewriter.getI64ArrayAttr(pos)); 72 } 73 74 // Helper that picks the proper sequence for inserting. 75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 76 Value into, int64_t offset) { 77 auto vectorType = into.getType().cast<VectorType>(); 78 if (vectorType.getRank() > 1) 79 return rewriter.create<InsertOp>(loc, from, into, offset); 80 return rewriter.create<vector::InsertElementOp>( 81 loc, vectorType, from, into, 82 rewriter.create<ConstantIndexOp>(loc, offset)); 83 } 84 85 // Helper that picks the proper sequence for extracting. 
86 static Value extractOne(ConversionPatternRewriter &rewriter, 87 LLVMTypeConverter &typeConverter, Location loc, 88 Value val, Type llvmType, int64_t rank, int64_t pos) { 89 if (rank == 1) { 90 auto idxType = rewriter.getIndexType(); 91 auto constant = rewriter.create<LLVM::ConstantOp>( 92 loc, typeConverter.convertType(idxType), 93 rewriter.getIntegerAttr(idxType, pos)); 94 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 95 constant); 96 } 97 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 98 rewriter.getI64ArrayAttr(pos)); 99 } 100 101 // Helper that picks the proper sequence for extracting. 102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 103 int64_t offset) { 104 auto vectorType = vector.getType().cast<VectorType>(); 105 if (vectorType.getRank() > 1) 106 return rewriter.create<ExtractOp>(loc, vector, offset); 107 return rewriter.create<vector::ExtractElementOp>( 108 loc, vectorType.getElementType(), vector, 109 rewriter.create<ConstantIndexOp>(loc, offset)); 110 } 111 112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing. 
// `dropFront` leading and `dropBack` trailing entries are omitted; values are
// sign-extended to int64_t.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Computes the preferred alignment (in bytes) of the memref element type of
// `xferOp` into `align`, as reported by the LLVM module's data layout.
// Fails if the element type cannot be converted to an LLVM type.
template <typename TransferOp>
LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
                                         TransferOp xferOp, unsigned &align) {
  Type elementTy =
      typeConverter.convertType(xferOp.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
  align = dataLayout.getPrefTypeAlignment(
      elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
  return success();
}

// Lowers an unmasked (in-bounds) transfer_read to a plain llvm.load of the
// vector from `dataPtr`.
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
  return success();
}

// Lowers a transfer_read to llvm.intr.masked.load: lanes disabled by `mask`
// yield the splatted padding value instead of touching memory.
static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  // Build the pass-through vector: padding splatted to the full vector type,
  // then cast into the LLVM dialect for use as the masked-load fallback.
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

// Lowers an unmasked (in-bounds) transfer_write to a plain llvm.store of the
// vector to `dataPtr`.
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
  return success();
}

// Lowers a transfer_write to llvm.intr.masked.store: only lanes enabled by
// `mask` are written to memory.
static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

// Overload pair letting templated code (VectorTransferConversion) pick the
// matching operand adaptor for a read or write transfer op.
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  // Forwards lhs/rhs and the row/column attributes unchanged to the LLVM
  // matrix intrinsic.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for vector.reduction. Lowers to the matching
/// llvm.intr.experimental.vector.reduce.* intrinsic, selected by the reduction
/// kind string and the element type. Only signless i32/i64 and f32/f64
/// element types are supported; anything else fails to match.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // Note: min/max lower to the *signed* variants (smin/smax).
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }
};

/// Conversion pattern for vector.shuffle. Rank-1 shuffles with identically
/// typed operands map directly onto llvm.shufflevector; all other cases are
/// expanded into per-position extract/insert pairs.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries >= v1Dim address the second operand, rebased by v1Dim.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

/// Conversion pattern for vector.extractelement on a 1-D vector; a direct
/// 1-1 mapping onto llvm.extractelement.
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.extract with static positions. A subvector
/// result needs only a single llvm.extractvalue; a scalar result first
/// extracts the innermost 1-D vector, then llvm.extractelement from it.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array: all but the last
    // position coordinate address the enclosing aggregate.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fma %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    // n-D FMAs are expected to have been rank-reduced to 1-D beforehand
    // (see VectorFMAOpNDRewritePattern).
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
                                             adaptor.acc());
    return success();
  }
};

/// Conversion pattern for vector.insertelement on a 1-D vector; a direct
/// 1-1 mapping onto llvm.insertelement.
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for vector.insert with static positions. Inserting a
/// vector into an aggregate needs only a single llvm.insertvalue; inserting
/// a scalar goes through the innermost 1-D vector: extract it, insertelement
/// into it, then insert it back into the aggregate.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array: all but the last
    // position coordinate address the enclosing aggregate.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
595 /// 596 /// Example: 597 /// ``` 598 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 599 /// ``` 600 /// is rewritten into: 601 /// ``` 602 /// %r = splat %f0: vector<2x4xf32> 603 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 604 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 605 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 606 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 607 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 608 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 609 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 610 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 611 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 612 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 613 /// // %r3 holds the final value. 614 /// ``` 615 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 616 public: 617 using OpRewritePattern<FMAOp>::OpRewritePattern; 618 619 LogicalResult matchAndRewrite(FMAOp op, 620 PatternRewriter &rewriter) const override { 621 auto vType = op.getVectorType(); 622 if (vType.getRank() < 2) 623 return failure(); 624 625 auto loc = op.getLoc(); 626 auto elemType = vType.getElementType(); 627 Value zero = rewriter.create<ConstantOp>(loc, elemType, 628 rewriter.getZeroAttr(elemType)); 629 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 630 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 631 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 632 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 633 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 634 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 635 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 636 } 637 rewriter.replaceOp(op, desc); 638 return success(); 639 } 640 }; 641 642 // When ranks are different, InsertStridedSlice needs to extract a properly 643 // ranked vector from 
the destination vector into which to insert. This pattern 644 // only takes care of this part and forwards the rest of the conversion to 645 // another pattern that converts InsertStridedSlice for operands of the same 646 // rank. 647 // 648 // RewritePattern for InsertStridedSliceOp where source and destination vectors 649 // have different ranks. In this case: 650 // 1. the proper subvector is extracted from the destination vector 651 // 2. a new InsertStridedSlice op is created to insert the source in the 652 // destination subvector 653 // 3. the destination subvector is inserted back in the proper place 654 // 4. the op is replaced by the result of step 3. 655 // The new InsertStridedSlice from step 2. will be picked up by a 656 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 657 class VectorInsertStridedSliceOpDifferentRankRewritePattern 658 : public OpRewritePattern<InsertStridedSliceOp> { 659 public: 660 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 661 662 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 663 PatternRewriter &rewriter) const override { 664 auto srcType = op.getSourceVectorType(); 665 auto dstType = op.getDestVectorType(); 666 667 if (op.offsets().getValue().empty()) 668 return failure(); 669 670 auto loc = op.getLoc(); 671 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 672 assert(rankDiff >= 0); 673 if (rankDiff == 0) 674 return failure(); 675 676 int64_t rankRest = dstType.getRank() - rankDiff; 677 // Extract / insert the subvector of matching rank and InsertStridedSlice 678 // on it. 679 Value extracted = 680 rewriter.create<ExtractOp>(loc, op.dest(), 681 getI64SubArray(op.offsets(), /*dropFront=*/0, 682 /*dropFront=*/rankRest)); 683 // A different pattern will kick in for InsertStridedSlice with matching 684 // ranks. 
685 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 686 loc, op.source(), extracted, 687 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 688 getI64SubArray(op.strides(), /*dropFront=*/0)); 689 rewriter.replaceOpWithNewOp<InsertOp>( 690 op, stridedSliceInnerOp.getResult(), op.dest(), 691 getI64SubArray(op.offsets(), /*dropFront=*/0, 692 /*dropFront=*/rankRest)); 693 return success(); 694 } 695 }; 696 697 // RewritePattern for InsertStridedSliceOp where source and destination vectors 698 // have the same rank. In this case, we reduce 699 // 1. the proper subvector is extracted from the destination vector 700 // 2. a new InsertStridedSlice op is created to insert the source in the 701 // destination subvector 702 // 3. the destination subvector is inserted back in the proper place 703 // 4. the op is replaced by the result of step 3. 704 // The new InsertStridedSlice from step 2. will be picked up by a 705 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 706 class VectorInsertStridedSliceOpSameRankRewritePattern 707 : public OpRewritePattern<InsertStridedSliceOp> { 708 public: 709 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 710 711 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 712 PatternRewriter &rewriter) const override { 713 auto srcType = op.getSourceVectorType(); 714 auto dstType = op.getDestVectorType(); 715 716 if (op.offsets().getValue().empty()) 717 return failure(); 718 719 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 720 assert(rankDiff >= 0); 721 if (rankDiff != 0) 722 return failure(); 723 724 if (srcType == dstType) { 725 rewriter.replaceOp(op, op.source()); 726 return success(); 727 } 728 729 int64_t offset = 730 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 731 int64_t size = srcType.getShape().front(); 732 int64_t stride = 733 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 734 735 auto loc = op.getLoc(); 736 Value res = op.dest(); 737 // For each 
slice of the source vector along the most major dimension. 738 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 739 off += stride, ++idx) { 740 // 1. extract the proper subvector (or element) from source 741 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 742 if (extractedSource.getType().isa<VectorType>()) { 743 // 2. If we have a vector, extract the proper subvector from destination 744 // Otherwise we are at the element level and no need to recurse. 745 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 746 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 747 // smaller rank. 748 extractedSource = rewriter.create<InsertStridedSliceOp>( 749 loc, extractedSource, extractedDest, 750 getI64SubArray(op.offsets(), /* dropFront=*/1), 751 getI64SubArray(op.strides(), /* dropFront=*/1)); 752 } 753 // 4. Insert the extractedSource into the res vector. 754 res = insertOne(rewriter, loc, extractedSource, res, off); 755 } 756 757 rewriter.replaceOp(op, res); 758 return success(); 759 } 760 /// This pattern creates recursive InsertStridedSliceOp, but the recursion is 761 /// bounded as the rank is strictly decreasing. 
  bool hasBoundedRewriteRecursion() const final { return true; }
};

/// Conversion pattern for vector::TypeCastOp between statically shaped
/// memrefs: builds a fresh target-side memref descriptor that reuses the
/// source's allocated/aligned pointers (bitcast to the target element type),
/// a zero offset, and the target shape's static sizes with the source's
/// strides. Bails out (failure) on dynamic shapes, non-struct LLVM
/// descriptors, or non-contiguous sources.
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The already-converted source operand must be an LLVM struct (a memref
    // descriptor) before we can wrap it in a MemRefDescriptor helper.
    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides =
        getStridesAndOffset(sourceMemRefType, strides, offset);
    // Contiguity check: innermost stride must be 1 and each outer stride must
    // equal the product of the next stride and size.
    // NOTE(review): strides.back() is read before successStrides is checked
    // below; if getStridesAndOffset failed or produced an empty stride list
    // this is undefined behavior — confirm callers guarantee a rank >= 1
    // strided memref here, or hoist the failed(successStrides) check.
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    // Only contiguous source tensors supported atm.
    if (failed(successStrides) || !isContiguous)
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref: static target sizes paired
    // with the source's (contiguous) strides.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled here.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // Only the minor-identity permutation map is supported (i.e. the transfer
    // reads/writes along the innermost memref dimension).
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // Unmasked dim: a plain (unconditional) load/store suffices.
    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim be the memref dimension; compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};

/// Lowers vector::PrintOp to calls into a small runtime print-support library
/// (declared on demand in the enclosing module).
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO(ajcbik): rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support
    // (currently i1, i32, i64, f32 and f64).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1))
      printer = getPrintI1(op);
    else if (eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively unrolls `value` one rank at a time, emitting bracket/comma
  // calls around the element-wise printer calls. At rank 0, `value` is a
  // scalar and is handed directly to `printer`.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Returns the existing LLVMFuncOp with `name` if present in the module,
  // otherwise declares a void function with the given parameter types.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI1(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i1",
                    LLVM::LLVMType::getInt1Ty(dialect));
  }
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. extractelement + insertelement for the 1-D case
///   2. extract + optional strided_slice + insert for the n-D case.
class VectorStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    // Only the leading (front) offset/size/stride is handled at this level;
    // the remaining dimensions are peeled off via a recursive
    // ExtractStridedSliceOp below.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());
    // Start from a zero-splat result and insert extracted elements one by one.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value extracted = extractOne(rewriter, loc, op.vector(), off);
      if (op.offsets().getValue().size() > 1) {
        extracted = rewriter.create<ExtractStridedSliceOp>(
            loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.sizes(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, {res});
    return success();
  }
  /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns
      .insert<VectorReductionOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion>(ctx, converter);
  // clang-format on
}

/// Populate the given list with the matrix-intrinsic lowering patterns
/// (VectorMatmulOpConversion and VectorFlatTransposeOpConversion).
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
/// Module pass that lowers Vector dialect ops (plus the remaining Standard
/// ops) to the LLVM dialect.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  // The inner scope ensures this greedy-rewrite pattern list is discarded
  // before the dialect-conversion list below is built.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  // NOTE(review): matrix conversion patterns are populated a second time here
  // (also two lines above); this looks redundant — confirm and drop one call.
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  if (failed(applyPartialConversion(getOperation(), target, patterns,
                                    &converter))) {
    signalPassFailure();
  }
}

/// Creates the pass that lowers Vector dialect ops to the LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
  return std::make_unique<LowerVectorToLLVMPass>();
}