1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Vector/VectorOps.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/Target/LLVMIR/TypeTranslation.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 using namespace mlir::vector; 22 23 // Helper to reduce vector type by one rank at front. 24 static VectorType reducedVectorTypeFront(VectorType tp) { 25 assert((tp.getRank() > 1) && "unlowerable vector type"); 26 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 27 } 28 29 // Helper to reduce vector type by *all* but one rank at back. 30 static VectorType reducedVectorTypeBack(VectorType tp) { 31 assert((tp.getRank() > 1) && "unlowerable vector type"); 32 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 33 } 34 35 // Helper that picks the proper sequence for inserting. 36 static Value insertOne(ConversionPatternRewriter &rewriter, 37 LLVMTypeConverter &typeConverter, Location loc, 38 Value val1, Value val2, Type llvmType, int64_t rank, 39 int64_t pos) { 40 if (rank == 1) { 41 auto idxType = rewriter.getIndexType(); 42 auto constant = rewriter.create<LLVM::ConstantOp>( 43 loc, typeConverter.convertType(idxType), 44 rewriter.getIntegerAttr(idxType, pos)); 45 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 46 constant); 47 } 48 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 49 rewriter.getI64ArrayAttr(pos)); 50 } 51 52 // Helper that picks the proper sequence for inserting. 53 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 54 Value into, int64_t offset) { 55 auto vectorType = into.getType().cast<VectorType>(); 56 if (vectorType.getRank() > 1) 57 return rewriter.create<InsertOp>(loc, from, into, offset); 58 return rewriter.create<vector::InsertElementOp>( 59 loc, vectorType, from, into, 60 rewriter.create<ConstantIndexOp>(loc, offset)); 61 } 62 63 // Helper that picks the proper sequence for extracting. 64 static Value extractOne(ConversionPatternRewriter &rewriter, 65 LLVMTypeConverter &typeConverter, Location loc, 66 Value val, Type llvmType, int64_t rank, int64_t pos) { 67 if (rank == 1) { 68 auto idxType = rewriter.getIndexType(); 69 auto constant = rewriter.create<LLVM::ConstantOp>( 70 loc, typeConverter.convertType(idxType), 71 rewriter.getIntegerAttr(idxType, pos)); 72 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 73 constant); 74 } 75 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 76 rewriter.getI64ArrayAttr(pos)); 77 } 78 79 // Helper that picks the proper sequence for extracting. 80 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 81 int64_t offset) { 82 auto vectorType = vector.getType().cast<VectorType>(); 83 if (vectorType.getRank() > 1) 84 return rewriter.create<ExtractOp>(loc, vector, offset); 85 return rewriter.create<vector::ExtractElementOp>( 86 loc, vectorType.getElementType(), vector, 87 rewriter.create<ConstantIndexOp>(loc, offset)); 88 } 89 90 // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 91 // TODO: Better support for attribute subtype forwarding + slicing. 92 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 93 unsigned dropFront = 0, 94 unsigned dropBack = 0) { 95 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 96 auto range = arrayAttr.getAsRange<IntegerAttr>(); 97 SmallVector<int64_t, 4> res; 98 res.reserve(arrayAttr.size() - dropFront - dropBack); 99 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 100 it != eit; ++it) 101 res.push_back((*it).getValue().getSExtValue()); 102 return res; 103 } 104 105 // Helper that returns a vector comparison that constructs a mask: 106 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 107 // 108 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 109 // much more compact, IR for this operation, but LLVM eventually 110 // generates more elaborate instructions for this intrinsic since it 111 // is very conservative on the boundary conditions. 112 static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 113 Operation *op, bool enableIndexOptimizations, 114 int64_t dim, Value b, Value *off = nullptr) { 115 auto loc = op->getLoc(); 116 // If we can assume all indices fit in 32-bit, we perform the vector 117 // comparison in 32-bit to get a higher degree of SIMD parallelism. 118 // Otherwise we perform the vector comparison using 64-bit indices. 119 Value indices; 120 Type idxType; 121 if (enableIndexOptimizations) { 122 indices = rewriter.create<ConstantOp>( 123 loc, rewriter.getI32VectorAttr( 124 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 125 idxType = rewriter.getI32Type(); 126 } else { 127 indices = rewriter.create<ConstantOp>( 128 loc, rewriter.getI64VectorAttr( 129 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 130 idxType = rewriter.getI64Type(); 131 } 132 // Add in an offset if requested. 133 if (off) { 134 Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 135 Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 136 indices = rewriter.create<AddIOp>(loc, ov, indices); 137 } 138 // Construct the vector comparison. 139 Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 140 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 141 return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 142 } 143 144 // Helper that returns data layout alignment of a memref. 145 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 146 MemRefType memrefType, unsigned &align) { 147 Type elementTy = typeConverter.convertType(memrefType.getElementType()); 148 if (!elementTy) 149 return failure(); 150 151 // TODO: this should use the MLIR data layout when it becomes available and 152 // stop depending on translation. 153 llvm::LLVMContext llvmContext; 154 align = LLVM::TypeToLLVMIRTranslator(llvmContext) 155 .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 156 return success(); 157 } 158 159 // Helper that returns the base address of a memref. 160 static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 161 Value memref, MemRefType memRefType, Value &base) { 162 // Inspect stride and offset structure. 163 // 164 // TODO: flat memory only for now, generalize 165 // 166 int64_t offset; 167 SmallVector<int64_t, 4> strides; 168 auto successStrides = getStridesAndOffset(memRefType, strides, offset); 169 if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 170 offset != 0 || memRefType.getMemorySpace() != 0) 171 return failure(); 172 base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 173 return success(); 174 } 175 176 // Helper that returns a pointer given a memref base. 177 static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 178 Location loc, Value memref, 179 MemRefType memRefType, Value &ptr) { 180 Value base; 181 if (failed(getBase(rewriter, loc, memref, memRefType, base))) 182 return failure(); 183 auto pType = MemRefDescriptor(memref).getElementPtrType(); 184 ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 185 return success(); 186 } 187 188 // Helper that returns a bit-casted pointer given a memref base. 189 static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 190 Location loc, Value memref, 191 MemRefType memRefType, Type type, Value &ptr) { 192 Value base; 193 if (failed(getBase(rewriter, loc, memref, memRefType, base))) 194 return failure(); 195 auto pType = LLVM::LLVMPointerType::get(type); 196 base = rewriter.create<LLVM::BitcastOp>(loc, pType, base); 197 ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 198 return success(); 199 } 200 201 // Helper that returns vector of pointers given a memref base and an index 202 // vector. 203 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 204 Location loc, Value memref, Value indices, 205 MemRefType memRefType, VectorType vType, 206 Type iType, Value &ptrs) { 207 Value base; 208 if (failed(getBase(rewriter, loc, memref, memRefType, base))) 209 return failure(); 210 auto pType = MemRefDescriptor(memref).getElementPtrType(); 211 auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0)); 212 ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 213 return success(); 214 } 215 216 static LogicalResult 217 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 218 LLVMTypeConverter &typeConverter, Location loc, 219 TransferReadOp xferOp, 220 ArrayRef<Value> operands, Value dataPtr) { 221 unsigned align; 222 if (failed(getMemRefAlignment( 223 typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 224 return failure(); 225 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 226 return success(); 227 } 228 229 static LogicalResult 230 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 231 LLVMTypeConverter &typeConverter, Location loc, 232 TransferReadOp xferOp, ArrayRef<Value> operands, 233 Value dataPtr, Value mask) { 234 auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 235 VectorType fillType = xferOp.getVectorType(); 236 Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 237 fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 238 239 Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 240 if (!vecTy) 241 return failure(); 242 243 unsigned align; 244 if (failed(getMemRefAlignment( 245 typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 246 return failure(); 247 248 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 249 xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 250 rewriter.getI32IntegerAttr(align)); 251 return success(); 252 } 253 254 static LogicalResult 255 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 256 LLVMTypeConverter &typeConverter, Location loc, 257 TransferWriteOp xferOp, 258 ArrayRef<Value> operands, Value dataPtr) { 259 unsigned align; 260 if (failed(getMemRefAlignment( 261 typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 262 return failure(); 263 auto adaptor = TransferWriteOpAdaptor(operands); 264 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 265 align); 266 return success(); 267 } 268 269 static LogicalResult 270 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 271 LLVMTypeConverter &typeConverter, Location loc, 272 TransferWriteOp xferOp, ArrayRef<Value> operands, 273 Value dataPtr, Value mask) { 274 unsigned align; 275 if (failed(getMemRefAlignment( 276 typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 277 return failure(); 278 279 auto adaptor = TransferWriteOpAdaptor(operands); 280 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 281 xferOp, adaptor.vector(), dataPtr, mask, 282 rewriter.getI32IntegerAttr(align)); 283 return success(); 284 } 285 286 static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 287 ArrayRef<Value> operands) { 288 return TransferReadOpAdaptor(operands); 289 } 290 291 static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 292 ArrayRef<Value> operands) { 293 return TransferWriteOpAdaptor(operands); 294 } 295 296 namespace { 297 298 /// Conversion pattern for a vector.matrix_multiply. 299 /// This is lowered directly to the proper llvm.intr.matrix.multiply. 300 class VectorMatmulOpConversion 301 : public ConvertOpToLLVMPattern<vector::MatmulOp> { 302 public: 303 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 304 305 LogicalResult 306 matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands, 307 ConversionPatternRewriter &rewriter) const override { 308 auto adaptor = vector::MatmulOpAdaptor(operands); 309 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 310 matmulOp, typeConverter->convertType(matmulOp.res().getType()), 311 adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), 312 matmulOp.lhs_columns(), matmulOp.rhs_columns()); 313 return success(); 314 } 315 }; 316 317 /// Conversion pattern for a vector.flat_transpose. 318 /// This is lowered directly to the proper llvm.intr.matrix.transpose. 319 class VectorFlatTransposeOpConversion 320 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 321 public: 322 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 323 324 LogicalResult 325 matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands, 326 ConversionPatternRewriter &rewriter) const override { 327 auto adaptor = vector::FlatTransposeOpAdaptor(operands); 328 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 329 transOp, typeConverter->convertType(transOp.res().getType()), 330 adaptor.matrix(), transOp.rows(), transOp.columns()); 331 return success(); 332 } 333 }; 334 335 /// Conversion pattern for a vector.maskedload. 336 class VectorMaskedLoadOpConversion 337 : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> { 338 public: 339 using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern; 340 341 LogicalResult 342 matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands, 343 ConversionPatternRewriter &rewriter) const override { 344 auto loc = load->getLoc(); 345 auto adaptor = vector::MaskedLoadOpAdaptor(operands); 346 347 // Resolve alignment. 348 unsigned align; 349 if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(), 350 align))) 351 return failure(); 352 353 auto vtype = typeConverter->convertType(load.getResultVectorType()); 354 Value ptr; 355 if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), 356 vtype, ptr))) 357 return failure(); 358 359 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 360 load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 361 rewriter.getI32IntegerAttr(align)); 362 return success(); 363 } 364 }; 365 366 /// Conversion pattern for a vector.maskedstore. 367 class VectorMaskedStoreOpConversion 368 : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> { 369 public: 370 using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern; 371 372 LogicalResult 373 matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands, 374 ConversionPatternRewriter &rewriter) const override { 375 auto loc = store->getLoc(); 376 auto adaptor = vector::MaskedStoreOpAdaptor(operands); 377 378 // Resolve alignment. 379 unsigned align; 380 if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(), 381 align))) 382 return failure(); 383 384 auto vtype = typeConverter->convertType(store.getValueVectorType()); 385 Value ptr; 386 if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), 387 vtype, ptr))) 388 return failure(); 389 390 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 391 store, adaptor.value(), ptr, adaptor.mask(), 392 rewriter.getI32IntegerAttr(align)); 393 return success(); 394 } 395 }; 396 397 /// Conversion pattern for a vector.gather. 398 class VectorGatherOpConversion 399 : public ConvertOpToLLVMPattern<vector::GatherOp> { 400 public: 401 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 402 403 LogicalResult 404 matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands, 405 ConversionPatternRewriter &rewriter) const override { 406 auto loc = gather->getLoc(); 407 auto adaptor = vector::GatherOpAdaptor(operands); 408 409 // Resolve alignment. 410 unsigned align; 411 if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(), 412 align))) 413 return failure(); 414 415 // Get index ptrs. 416 VectorType vType = gather.getResultVectorType(); 417 Type iType = gather.getIndicesVectorType().getElementType(); 418 Value ptrs; 419 if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 420 gather.getMemRefType(), vType, iType, ptrs))) 421 return failure(); 422 423 // Replace with the gather intrinsic. 424 rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 425 gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 426 adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 427 return success(); 428 } 429 }; 430 431 /// Conversion pattern for a vector.scatter. 432 class VectorScatterOpConversion 433 : public ConvertOpToLLVMPattern<vector::ScatterOp> { 434 public: 435 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 436 437 LogicalResult 438 matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands, 439 ConversionPatternRewriter &rewriter) const override { 440 auto loc = scatter->getLoc(); 441 auto adaptor = vector::ScatterOpAdaptor(operands); 442 443 // Resolve alignment. 444 unsigned align; 445 if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(), 446 align))) 447 return failure(); 448 449 // Get index ptrs. 450 VectorType vType = scatter.getValueVectorType(); 451 Type iType = scatter.getIndicesVectorType().getElementType(); 452 Value ptrs; 453 if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 454 scatter.getMemRefType(), vType, iType, ptrs))) 455 return failure(); 456 457 // Replace with the scatter intrinsic. 458 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 459 scatter, adaptor.value(), ptrs, adaptor.mask(), 460 rewriter.getI32IntegerAttr(align)); 461 return success(); 462 } 463 }; 464 465 /// Conversion pattern for a vector.expandload. 466 class VectorExpandLoadOpConversion 467 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 468 public: 469 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 470 471 LogicalResult 472 matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands, 473 ConversionPatternRewriter &rewriter) const override { 474 auto loc = expand->getLoc(); 475 auto adaptor = vector::ExpandLoadOpAdaptor(operands); 476 477 Value ptr; 478 if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), 479 ptr))) 480 return failure(); 481 482 auto vType = expand.getResultVectorType(); 483 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 484 expand, typeConverter->convertType(vType), ptr, adaptor.mask(), 485 adaptor.pass_thru()); 486 return success(); 487 } 488 }; 489 490 /// Conversion pattern for a vector.compressstore. 491 class VectorCompressStoreOpConversion 492 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 493 public: 494 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 495 496 LogicalResult 497 matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands, 498 ConversionPatternRewriter &rewriter) const override { 499 auto loc = compress->getLoc(); 500 auto adaptor = vector::CompressStoreOpAdaptor(operands); 501 502 Value ptr; 503 if (failed(getBasePtr(rewriter, loc, adaptor.base(), 504 compress.getMemRefType(), ptr))) 505 return failure(); 506 507 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 508 compress, adaptor.value(), ptr, adaptor.mask()); 509 return success(); 510 } 511 }; 512 513 /// Conversion pattern for all vector reductions. 514 class VectorReductionOpConversion 515 : public ConvertOpToLLVMPattern<vector::ReductionOp> { 516 public: 517 explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 518 bool reassociateFPRed) 519 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 520 reassociateFPReductions(reassociateFPRed) {} 521 522 LogicalResult 523 matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands, 524 ConversionPatternRewriter &rewriter) const override { 525 auto kind = reductionOp.kind(); 526 Type eltType = reductionOp.dest().getType(); 527 Type llvmType = typeConverter->convertType(eltType); 528 if (eltType.isIntOrIndex()) { 529 // Integer reductions: add/mul/min/max/and/or/xor. 530 if (kind == "add") 531 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 532 reductionOp, llvmType, operands[0]); 533 else if (kind == "mul") 534 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 535 reductionOp, llvmType, operands[0]); 536 else if (kind == "min" && 537 (eltType.isIndex() || eltType.isUnsignedInteger())) 538 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 539 reductionOp, llvmType, operands[0]); 540 else if (kind == "min") 541 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 542 reductionOp, llvmType, operands[0]); 543 else if (kind == "max" && 544 (eltType.isIndex() || eltType.isUnsignedInteger())) 545 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 546 reductionOp, llvmType, operands[0]); 547 else if (kind == "max") 548 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 549 reductionOp, llvmType, operands[0]); 550 else if (kind == "and") 551 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 552 reductionOp, llvmType, operands[0]); 553 else if (kind == "or") 554 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 555 reductionOp, llvmType, operands[0]); 556 else if (kind == "xor") 557 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 558 reductionOp, llvmType, operands[0]); 559 else 560 return failure(); 561 return success(); 562 } 563 564 if (!eltType.isa<FloatType>()) 565 return failure(); 566 567 // Floating-point reductions: add/mul/min/max 568 if (kind == "add") { 569 // Optional accumulator (or zero). 570 Value acc = operands.size() > 1 ? operands[1] 571 : rewriter.create<LLVM::ConstantOp>( 572 reductionOp->getLoc(), llvmType, 573 rewriter.getZeroAttr(eltType)); 574 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 575 reductionOp, llvmType, acc, operands[0], 576 rewriter.getBoolAttr(reassociateFPReductions)); 577 } else if (kind == "mul") { 578 // Optional accumulator (or one). 579 Value acc = operands.size() > 1 580 ? operands[1] 581 : rewriter.create<LLVM::ConstantOp>( 582 reductionOp->getLoc(), llvmType, 583 rewriter.getFloatAttr(eltType, 1.0)); 584 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 585 reductionOp, llvmType, acc, operands[0], 586 rewriter.getBoolAttr(reassociateFPReductions)); 587 } else if (kind == "min") 588 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 589 reductionOp, llvmType, operands[0]); 590 else if (kind == "max") 591 rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 592 reductionOp, llvmType, operands[0]); 593 else 594 return failure(); 595 return success(); 596 } 597 598 private: 599 const bool reassociateFPReductions; 600 }; 601 602 /// Conversion pattern for a vector.create_mask (1-D only). 603 class VectorCreateMaskOpConversion 604 : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 605 public: 606 explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, 607 bool enableIndexOpt) 608 : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), 609 enableIndexOptimizations(enableIndexOpt) {} 610 611 LogicalResult 612 matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, 613 ConversionPatternRewriter &rewriter) const override { 614 auto dstType = op.getType(); 615 int64_t rank = dstType.getRank(); 616 if (rank == 1) { 617 rewriter.replaceOp( 618 op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 619 dstType.getDimSize(0), operands[0])); 620 return success(); 621 } 622 return failure(); 623 } 624 625 private: 626 const bool enableIndexOptimizations; 627 }; 628 629 class VectorShuffleOpConversion 630 : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 631 public: 632 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 633 634 LogicalResult 635 matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands, 636 ConversionPatternRewriter &rewriter) const override { 637 auto loc = shuffleOp->getLoc(); 638 auto adaptor = vector::ShuffleOpAdaptor(operands); 639 auto v1Type = shuffleOp.getV1VectorType(); 640 auto v2Type = shuffleOp.getV2VectorType(); 641 auto vectorType = shuffleOp.getVectorType(); 642 Type llvmType = typeConverter->convertType(vectorType); 643 auto maskArrayAttr = shuffleOp.mask(); 644 645 // Bail if result type cannot be lowered. 646 if (!llvmType) 647 return failure(); 648 649 // Get rank and dimension sizes. 650 int64_t rank = vectorType.getRank(); 651 assert(v1Type.getRank() == rank); 652 assert(v2Type.getRank() == rank); 653 int64_t v1Dim = v1Type.getDimSize(0); 654 655 // For rank 1, where both operands have *exactly* the same vector type, 656 // there is direct shuffle support in LLVM. Use it! 657 if (rank == 1 && v1Type == v2Type) { 658 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 659 loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 660 rewriter.replaceOp(shuffleOp, llvmShuffleOp); 661 return success(); 662 } 663 664 // For all other cases, insert the individual values individually. 665 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 666 int64_t insPos = 0; 667 for (auto en : llvm::enumerate(maskArrayAttr)) { 668 int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 669 Value value = adaptor.v1(); 670 if (extPos >= v1Dim) { 671 extPos -= v1Dim; 672 value = adaptor.v2(); 673 } 674 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 675 llvmType, rank, extPos); 676 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 677 llvmType, rank, insPos++); 678 } 679 rewriter.replaceOp(shuffleOp, insert); 680 return success(); 681 } 682 }; 683 684 class VectorExtractElementOpConversion 685 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 686 public: 687 using ConvertOpToLLVMPattern< 688 vector::ExtractElementOp>::ConvertOpToLLVMPattern; 689 690 LogicalResult 691 matchAndRewrite(vector::ExtractElementOp extractEltOp, 692 ArrayRef<Value> operands, 693 ConversionPatternRewriter &rewriter) const override { 694 auto adaptor = vector::ExtractElementOpAdaptor(operands); 695 auto vectorType = extractEltOp.getVectorType(); 696 auto llvmType = typeConverter->convertType(vectorType.getElementType()); 697 698 // Bail if result type cannot be lowered. 699 if (!llvmType) 700 return failure(); 701 702 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 703 extractEltOp, llvmType, adaptor.vector(), adaptor.position()); 704 return success(); 705 } 706 }; 707 708 class VectorExtractOpConversion 709 : public ConvertOpToLLVMPattern<vector::ExtractOp> { 710 public: 711 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 712 713 LogicalResult 714 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 715 ConversionPatternRewriter &rewriter) const override { 716 auto loc = extractOp->getLoc(); 717 auto adaptor = vector::ExtractOpAdaptor(operands); 718 auto vectorType = extractOp.getVectorType(); 719 auto resultType = extractOp.getResult().getType(); 720 auto llvmResultType = typeConverter->convertType(resultType); 721 auto positionArrayAttr = extractOp.position(); 722 723 // Bail if result type cannot be lowered. 724 if (!llvmResultType) 725 return failure(); 726 727 // One-shot extraction of vector from array (only requires extractvalue). 728 if (resultType.isa<VectorType>()) { 729 Value extracted = rewriter.create<LLVM::ExtractValueOp>( 730 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 731 rewriter.replaceOp(extractOp, extracted); 732 return success(); 733 } 734 735 // Potential extraction of 1-D vector from array. 736 auto *context = extractOp->getContext(); 737 Value extracted = adaptor.vector(); 738 auto positionAttrs = positionArrayAttr.getValue(); 739 if (positionAttrs.size() > 1) { 740 auto oneDVectorType = reducedVectorTypeBack(vectorType); 741 auto nMinusOnePositionAttrs = 742 ArrayAttr::get(positionAttrs.drop_back(), context); 743 extracted = rewriter.create<LLVM::ExtractValueOp>( 744 loc, typeConverter->convertType(oneDVectorType), extracted, 745 nMinusOnePositionAttrs); 746 } 747 748 // Remaining extraction of element from 1-D LLVM vector 749 auto position = positionAttrs.back().cast<IntegerAttr>(); 750 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 751 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 752 extracted = 753 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 754 rewriter.replaceOp(extractOp, extracted); 755 756 return success(); 757 } 758 }; 759 760 /// Conversion pattern that turns a vector.fma on a 1-D vector 761 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 762 /// This does not match vectors of n >= 2 rank. 763 /// 764 /// Example: 765 /// ``` 766 /// vector.fma %a, %a, %a : vector<8xf32> 767 /// ``` 768 /// is converted to: 769 /// ``` 770 /// llvm.intr.fmuladd %va, %va, %va: 771 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 772 /// -> !llvm."<8 x f32>"> 773 /// ``` 774 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 775 public: 776 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 777 778 LogicalResult 779 matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 780 ConversionPatternRewriter &rewriter) const override { 781 auto adaptor = vector::FMAOpAdaptor(operands); 782 VectorType vType = fmaOp.getVectorType(); 783 if (vType.getRank() != 1) 784 return failure(); 785 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 786 adaptor.rhs(), adaptor.acc()); 787 return success(); 788 } 789 }; 790 791 class VectorInsertElementOpConversion 792 : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 793 public: 794 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 795 796 LogicalResult 797 matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands, 798 ConversionPatternRewriter &rewriter) const override { 799 auto adaptor = vector::InsertElementOpAdaptor(operands); 800 auto vectorType = insertEltOp.getDestVectorType(); 801 auto llvmType = typeConverter->convertType(vectorType); 802 803 // Bail if result type cannot be lowered. 804 if (!llvmType) 805 return failure(); 806 807 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 808 insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 809 adaptor.position()); 810 return success(); 811 } 812 }; 813 814 class VectorInsertOpConversion 815 : public ConvertOpToLLVMPattern<vector::InsertOp> { 816 public: 817 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 818 819 LogicalResult 820 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 821 ConversionPatternRewriter &rewriter) const override { 822 auto loc = insertOp->getLoc(); 823 auto adaptor = vector::InsertOpAdaptor(operands); 824 auto sourceType = insertOp.getSourceType(); 825 auto destVectorType = insertOp.getDestVectorType(); 826 auto llvmResultType = typeConverter->convertType(destVectorType); 827 auto positionArrayAttr = insertOp.position(); 828 829 // Bail if result type cannot be lowered. 830 if (!llvmResultType) 831 return failure(); 832 833 // One-shot insertion of a vector into an array (only requires insertvalue). 834 if (sourceType.isa<VectorType>()) { 835 Value inserted = rewriter.create<LLVM::InsertValueOp>( 836 loc, llvmResultType, adaptor.dest(), adaptor.source(), 837 positionArrayAttr); 838 rewriter.replaceOp(insertOp, inserted); 839 return success(); 840 } 841 842 // Potential extraction of 1-D vector from array. 843 auto *context = insertOp->getContext(); 844 Value extracted = adaptor.dest(); 845 auto positionAttrs = positionArrayAttr.getValue(); 846 auto position = positionAttrs.back().cast<IntegerAttr>(); 847 auto oneDVectorType = destVectorType; 848 if (positionAttrs.size() > 1) { 849 oneDVectorType = reducedVectorTypeBack(destVectorType); 850 auto nMinusOnePositionAttrs = 851 ArrayAttr::get(positionAttrs.drop_back(), context); 852 extracted = rewriter.create<LLVM::ExtractValueOp>( 853 loc, typeConverter->convertType(oneDVectorType), extracted, 854 nMinusOnePositionAttrs); 855 } 856 857 // Insertion of an element into a 1-D LLVM vector. 858 auto i64Type = IntegerType::get(rewriter.getContext(), 64); 859 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 860 Value inserted = rewriter.create<LLVM::InsertElementOp>( 861 loc, typeConverter->convertType(oneDVectorType), extracted, 862 adaptor.source(), constant); 863 864 // Potential insertion of resulting 1-D vector into array. 865 if (positionAttrs.size() > 1) { 866 auto nMinusOnePositionAttrs = 867 ArrayAttr::get(positionAttrs.drop_back(), context); 868 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 869 adaptor.dest(), inserted, 870 nMinusOnePositionAttrs); 871 } 872 873 rewriter.replaceOp(insertOp, inserted); 874 return success(); 875 } 876 }; 877 878 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 879 /// 880 /// Example: 881 /// ``` 882 /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 883 /// ``` 884 /// is rewritten into: 885 /// ``` 886 /// %r = splat %f0: vector<2x4xf32> 887 /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 888 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 889 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 890 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 891 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 892 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 893 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 894 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 895 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 896 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 897 /// // %r3 holds the final value. 898 /// ``` 899 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 900 public: 901 using OpRewritePattern<FMAOp>::OpRewritePattern; 902 903 LogicalResult matchAndRewrite(FMAOp op, 904 PatternRewriter &rewriter) const override { 905 auto vType = op.getVectorType(); 906 if (vType.getRank() < 2) 907 return failure(); 908 909 auto loc = op.getLoc(); 910 auto elemType = vType.getElementType(); 911 Value zero = rewriter.create<ConstantOp>(loc, elemType, 912 rewriter.getZeroAttr(elemType)); 913 Value desc = rewriter.create<SplatOp>(loc, vType, zero); 914 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 915 Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 916 Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 917 Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 918 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 919 desc = rewriter.create<InsertOp>(loc, fma, desc, i); 920 } 921 rewriter.replaceOp(op, desc); 922 return success(); 923 } 924 }; 925 926 // When ranks are different, InsertStridedSlice needs to extract a properly 927 // ranked vector from the destination vector into which to insert. This pattern 928 // only takes care of this part and forwards the rest of the conversion to 929 // another pattern that converts InsertStridedSlice for operands of the same 930 // rank. 931 // 932 // RewritePattern for InsertStridedSliceOp where source and destination vectors 933 // have different ranks. In this case: 934 // 1. the proper subvector is extracted from the destination vector 935 // 2. a new InsertStridedSlice op is created to insert the source in the 936 // destination subvector 937 // 3. the destination subvector is inserted back in the proper place 938 // 4. the op is replaced by the result of step 3. 939 // The new InsertStridedSlice from step 2. will be picked up by a 940 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 941 class VectorInsertStridedSliceOpDifferentRankRewritePattern 942 : public OpRewritePattern<InsertStridedSliceOp> { 943 public: 944 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 945 946 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 947 PatternRewriter &rewriter) const override { 948 auto srcType = op.getSourceVectorType(); 949 auto dstType = op.getDestVectorType(); 950 951 if (op.offsets().getValue().empty()) 952 return failure(); 953 954 auto loc = op.getLoc(); 955 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 956 assert(rankDiff >= 0); 957 if (rankDiff == 0) 958 return failure(); 959 960 int64_t rankRest = dstType.getRank() - rankDiff; 961 // Extract / insert the subvector of matching rank and InsertStridedSlice 962 // on it. 963 Value extracted = 964 rewriter.create<ExtractOp>(loc, op.dest(), 965 getI64SubArray(op.offsets(), /*dropFront=*/0, 966 /*dropBack=*/rankRest)); 967 // A different pattern will kick in for InsertStridedSlice with matching 968 // ranks. 969 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 970 loc, op.source(), extracted, 971 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 972 getI64SubArray(op.strides(), /*dropFront=*/0)); 973 rewriter.replaceOpWithNewOp<InsertOp>( 974 op, stridedSliceInnerOp.getResult(), op.dest(), 975 getI64SubArray(op.offsets(), /*dropFront=*/0, 976 /*dropBack=*/rankRest)); 977 return success(); 978 } 979 }; 980 981 // RewritePattern for InsertStridedSliceOp where source and destination vectors 982 // have the same rank. In this case, we reduce 983 // 1. the proper subvector is extracted from the destination vector 984 // 2. a new InsertStridedSlice op is created to insert the source in the 985 // destination subvector 986 // 3. the destination subvector is inserted back in the proper place 987 // 4. the op is replaced by the result of step 3. 988 // The new InsertStridedSlice from step 2. will be picked up by a 989 // `VectorInsertStridedSliceOpSameRankRewritePattern`. 990 class VectorInsertStridedSliceOpSameRankRewritePattern 991 : public OpRewritePattern<InsertStridedSliceOp> { 992 public: 993 VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 994 : OpRewritePattern<InsertStridedSliceOp>(ctx) { 995 // This pattern creates recursive InsertStridedSliceOp, but the recursion is 996 // bounded as the rank is strictly decreasing. 997 setHasBoundedRewriteRecursion(); 998 } 999 1000 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 1001 PatternRewriter &rewriter) const override { 1002 auto srcType = op.getSourceVectorType(); 1003 auto dstType = op.getDestVectorType(); 1004 1005 if (op.offsets().getValue().empty()) 1006 return failure(); 1007 1008 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 1009 assert(rankDiff >= 0); 1010 if (rankDiff != 0) 1011 return failure(); 1012 1013 if (srcType == dstType) { 1014 rewriter.replaceOp(op, op.source()); 1015 return success(); 1016 } 1017 1018 int64_t offset = 1019 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1020 int64_t size = srcType.getShape().front(); 1021 int64_t stride = 1022 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1023 1024 auto loc = op.getLoc(); 1025 Value res = op.dest(); 1026 // For each slice of the source vector along the most major dimension. 1027 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1028 off += stride, ++idx) { 1029 // 1. extract the proper subvector (or element) from source 1030 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 1031 if (extractedSource.getType().isa<VectorType>()) { 1032 // 2. If we have a vector, extract the proper subvector from destination 1033 // Otherwise we are at the element level and no need to recurse. 1034 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 1035 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 1036 // smaller rank. 1037 extractedSource = rewriter.create<InsertStridedSliceOp>( 1038 loc, extractedSource, extractedDest, 1039 getI64SubArray(op.offsets(), /* dropFront=*/1), 1040 getI64SubArray(op.strides(), /* dropFront=*/1)); 1041 } 1042 // 4. Insert the extractedSource into the res vector. 1043 res = insertOne(rewriter, loc, extractedSource, res, off); 1044 } 1045 1046 rewriter.replaceOp(op, res); 1047 return success(); 1048 } 1049 }; 1050 1051 /// Returns the strides if the memory underlying `memRefType` has a contiguous 1052 /// static layout. 1053 static llvm::Optional<SmallVector<int64_t, 4>> 1054 computeContiguousStrides(MemRefType memRefType) { 1055 int64_t offset; 1056 SmallVector<int64_t, 4> strides; 1057 if (failed(getStridesAndOffset(memRefType, strides, offset))) 1058 return None; 1059 if (!strides.empty() && strides.back() != 1) 1060 return None; 1061 // If no layout or identity layout, this is contiguous by definition. 1062 if (memRefType.getAffineMaps().empty() || 1063 memRefType.getAffineMaps().front().isIdentity()) 1064 return strides; 1065 1066 // Otherwise, we must determine contiguity form shapes. This can only ever 1067 // work in static cases because MemRefType is underspecified to represent 1068 // contiguous dynamic shapes in other ways than with just empty/identity 1069 // layout. 1070 auto sizes = memRefType.getShape(); 1071 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 1072 if (ShapedType::isDynamic(sizes[index + 1]) || 1073 ShapedType::isDynamicStrideOrOffset(strides[index]) || 1074 ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 1075 return None; 1076 if (strides[index] != strides[index + 1] * sizes[index + 1]) 1077 return None; 1078 } 1079 return strides; 1080 } 1081 1082 class VectorTypeCastOpConversion 1083 : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 1084 public: 1085 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 1086 1087 LogicalResult 1088 matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, 1089 ConversionPatternRewriter &rewriter) const override { 1090 auto loc = castOp->getLoc(); 1091 MemRefType sourceMemRefType = 1092 castOp.getOperand().getType().cast<MemRefType>(); 1093 MemRefType targetMemRefType = castOp.getType(); 1094 1095 // Only static shape casts supported atm. 1096 if (!sourceMemRefType.hasStaticShape() || 1097 !targetMemRefType.hasStaticShape()) 1098 return failure(); 1099 1100 auto llvmSourceDescriptorTy = 1101 operands[0].getType().dyn_cast<LLVM::LLVMStructType>(); 1102 if (!llvmSourceDescriptorTy) 1103 return failure(); 1104 MemRefDescriptor sourceMemRef(operands[0]); 1105 1106 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 1107 .dyn_cast_or_null<LLVM::LLVMStructType>(); 1108 if (!llvmTargetDescriptorTy) 1109 return failure(); 1110 1111 // Only contiguous source buffers supported atm. 1112 auto sourceStrides = computeContiguousStrides(sourceMemRefType); 1113 if (!sourceStrides) 1114 return failure(); 1115 auto targetStrides = computeContiguousStrides(targetMemRefType); 1116 if (!targetStrides) 1117 return failure(); 1118 // Only support static strides for now, regardless of contiguity. 1119 if (llvm::any_of(*targetStrides, [](int64_t stride) { 1120 return ShapedType::isDynamicStrideOrOffset(stride); 1121 })) 1122 return failure(); 1123 1124 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 1125 1126 // Create descriptor. 1127 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 1128 Type llvmTargetElementTy = desc.getElementPtrType(); 1129 // Set allocated ptr. 1130 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 1131 allocated = 1132 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 1133 desc.setAllocatedPtr(rewriter, loc, allocated); 1134 // Set aligned ptr. 1135 Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 1136 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 1137 desc.setAlignedPtr(rewriter, loc, ptr); 1138 // Fill offset 0. 1139 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 1140 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 1141 desc.setOffset(rewriter, loc, zero); 1142 1143 // Fill size and stride descriptors in memref. 1144 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 1145 int64_t index = indexedSize.index(); 1146 auto sizeAttr = 1147 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 1148 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 1149 desc.setSize(rewriter, loc, index, size); 1150 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 1151 (*targetStrides)[index]); 1152 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 1153 desc.setStride(rewriter, loc, index, stride); 1154 } 1155 1156 rewriter.replaceOp(castOp, {desc}); 1157 return success(); 1158 } 1159 }; 1160 1161 /// Conversion pattern that converts a 1-D vector transfer read/write op in a 1162 /// sequence of: 1163 /// 1. Get the source/dst address as an LLVM vector pointer. 1164 /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1165 /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1166 /// 4. Create a mask where offsetVector is compared against memref upper bound. 1167 /// 5. Rewrite op as a masked read or write. 1168 template <typename ConcreteOp> 1169 class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> { 1170 public: 1171 explicit VectorTransferConversion(LLVMTypeConverter &typeConv, 1172 bool enableIndexOpt) 1173 : ConvertOpToLLVMPattern<ConcreteOp>(typeConv), 1174 enableIndexOptimizations(enableIndexOpt) {} 1175 1176 LogicalResult 1177 matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands, 1178 ConversionPatternRewriter &rewriter) const override { 1179 auto adaptor = getTransferOpAdapter(xferOp, operands); 1180 1181 if (xferOp.getVectorType().getRank() > 1 || 1182 llvm::size(xferOp.indices()) == 0) 1183 return failure(); 1184 if (xferOp.permutation_map() != 1185 AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 1186 xferOp.getVectorType().getRank(), 1187 xferOp->getContext())) 1188 return failure(); 1189 auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 1190 if (!memRefType) 1191 return failure(); 1192 // Only contiguous source tensors supported atm. 1193 auto strides = computeContiguousStrides(memRefType); 1194 if (!strides) 1195 return failure(); 1196 1197 auto toLLVMTy = [&](Type t) { 1198 return this->getTypeConverter()->convertType(t); 1199 }; 1200 1201 Location loc = xferOp->getLoc(); 1202 1203 if (auto memrefVectorElementType = 1204 memRefType.getElementType().template dyn_cast<VectorType>()) { 1205 // Memref has vector element type. 1206 if (memrefVectorElementType.getElementType() != 1207 xferOp.getVectorType().getElementType()) 1208 return failure(); 1209 #ifndef NDEBUG 1210 // Check that memref vector type is a suffix of 'vectorType. 1211 unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 1212 unsigned resultVecRank = xferOp.getVectorType().getRank(); 1213 assert(memrefVecEltRank <= resultVecRank); 1214 // TODO: Move this to isSuffix in Vector/Utils.h. 1215 unsigned rankOffset = resultVecRank - memrefVecEltRank; 1216 auto memrefVecEltShape = memrefVectorElementType.getShape(); 1217 auto resultVecShape = xferOp.getVectorType().getShape(); 1218 for (unsigned i = 0; i < memrefVecEltRank; ++i) 1219 assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 1220 "memref vector element shape should match suffix of vector " 1221 "result shape."); 1222 #endif // ifndef NDEBUG 1223 } 1224 1225 // 1. Get the source/dst address as an LLVM vector pointer. 1226 // The vector pointer would always be on address space 0, therefore 1227 // addrspacecast shall be used when source/dst memrefs are not on 1228 // address space 0. 1229 // TODO: support alignment when possible. 1230 Value dataPtr = this->getStridedElementPtr( 1231 loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); 1232 auto vecTy = toLLVMTy(xferOp.getVectorType()) 1233 .template cast<LLVM::LLVMFixedVectorType>(); 1234 Value vectorDataPtr; 1235 if (memRefType.getMemorySpace() == 0) 1236 vectorDataPtr = rewriter.create<LLVM::BitcastOp>( 1237 loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); 1238 else 1239 vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 1240 loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); 1241 1242 if (!xferOp.isMaskedDim(0)) 1243 return replaceTransferOpWithLoadOrStore(rewriter, 1244 *this->getTypeConverter(), loc, 1245 xferOp, operands, vectorDataPtr); 1246 1247 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1248 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1249 // 4. Let dim the memref dimension, compute the vector comparison mask: 1250 // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1251 // 1252 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1253 // dimensions here. 1254 unsigned vecWidth = vecTy.getNumElements(); 1255 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 1256 Value off = xferOp.indices()[lastIndex]; 1257 Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex); 1258 Value mask = buildVectorComparison( 1259 rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); 1260 1261 // 5. Rewrite as a masked read / write. 1262 return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, 1263 xferOp, operands, vectorDataPtr, mask); 1264 } 1265 1266 private: 1267 const bool enableIndexOptimizations; 1268 }; 1269 1270 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1271 public: 1272 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1273 1274 // Proof-of-concept lowering implementation that relies on a small 1275 // runtime support library, which only needs to provide a few 1276 // printing methods (single value for all data types, opening/closing 1277 // bracket, comma, newline). The lowering fully unrolls a vector 1278 // in terms of these elementary printing operations. The advantage 1279 // of this approach is that the library can remain unaware of all 1280 // low-level implementation details of vectors while still supporting 1281 // output of any shaped and dimensioned vector. Due to full unrolling, 1282 // this approach is less suited for very large vectors though. 1283 // 1284 // TODO: rely solely on libc in future? something else? 1285 // 1286 LogicalResult 1287 matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands, 1288 ConversionPatternRewriter &rewriter) const override { 1289 auto adaptor = vector::PrintOpAdaptor(operands); 1290 Type printType = printOp.getPrintType(); 1291 1292 if (typeConverter->convertType(printType) == nullptr) 1293 return failure(); 1294 1295 // Make sure element type has runtime support. 1296 PrintConversion conversion = PrintConversion::None; 1297 VectorType vectorType = printType.dyn_cast<VectorType>(); 1298 Type eltType = vectorType ? vectorType.getElementType() : printType; 1299 Operation *printer; 1300 if (eltType.isF32()) { 1301 printer = getPrintFloat(printOp); 1302 } else if (eltType.isF64()) { 1303 printer = getPrintDouble(printOp); 1304 } else if (eltType.isIndex()) { 1305 printer = getPrintU64(printOp); 1306 } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1307 // Integers need a zero or sign extension on the operand 1308 // (depending on the source type) as well as a signed or 1309 // unsigned print method. Up to 64-bit is supported. 1310 unsigned width = intTy.getWidth(); 1311 if (intTy.isUnsigned()) { 1312 if (width <= 64) { 1313 if (width < 64) 1314 conversion = PrintConversion::ZeroExt64; 1315 printer = getPrintU64(printOp); 1316 } else { 1317 return failure(); 1318 } 1319 } else { 1320 assert(intTy.isSignless() || intTy.isSigned()); 1321 if (width <= 64) { 1322 // Note that we *always* zero extend booleans (1-bit integers), 1323 // so that true/false is printed as 1/0 rather than -1/0. 1324 if (width == 1) 1325 conversion = PrintConversion::ZeroExt64; 1326 else if (width < 64) 1327 conversion = PrintConversion::SignExt64; 1328 printer = getPrintI64(printOp); 1329 } else { 1330 return failure(); 1331 } 1332 } 1333 } else { 1334 return failure(); 1335 } 1336 1337 // Unroll vector into elementary print calls. 1338 int64_t rank = vectorType ? vectorType.getRank() : 0; 1339 emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, 1340 conversion); 1341 emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp)); 1342 rewriter.eraseOp(printOp); 1343 return success(); 1344 } 1345 1346 private: 1347 enum class PrintConversion { 1348 // clang-format off 1349 None, 1350 ZeroExt64, 1351 SignExt64 1352 // clang-format on 1353 }; 1354 1355 void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1356 Value value, VectorType vectorType, Operation *printer, 1357 int64_t rank, PrintConversion conversion) const { 1358 Location loc = op->getLoc(); 1359 if (rank == 0) { 1360 switch (conversion) { 1361 case PrintConversion::ZeroExt64: 1362 value = rewriter.create<ZeroExtendIOp>( 1363 loc, value, IntegerType::get(rewriter.getContext(), 64)); 1364 break; 1365 case PrintConversion::SignExt64: 1366 value = rewriter.create<SignExtendIOp>( 1367 loc, value, IntegerType::get(rewriter.getContext(), 64)); 1368 break; 1369 case PrintConversion::None: 1370 break; 1371 } 1372 emitCall(rewriter, loc, printer, value); 1373 return; 1374 } 1375 1376 emitCall(rewriter, loc, getPrintOpen(op)); 1377 Operation *printComma = getPrintComma(op); 1378 int64_t dim = vectorType.getDimSize(0); 1379 for (int64_t d = 0; d < dim; ++d) { 1380 auto reducedType = 1381 rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1382 auto llvmType = typeConverter->convertType( 1383 rank > 1 ? reducedType : vectorType.getElementType()); 1384 Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1385 llvmType, rank, d); 1386 emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1387 conversion); 1388 if (d != dim - 1) 1389 emitCall(rewriter, loc, printComma); 1390 } 1391 emitCall(rewriter, loc, getPrintClose(op)); 1392 } 1393 1394 // Helper to emit a call. 1395 static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1396 Operation *ref, ValueRange params = ValueRange()) { 1397 rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1398 rewriter.getSymbolRefAttr(ref), params); 1399 } 1400 1401 // Helper for printer method declaration (first hit) and lookup. 1402 static Operation *getPrint(Operation *op, StringRef name, 1403 ArrayRef<Type> params) { 1404 auto module = op->getParentOfType<ModuleOp>(); 1405 auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1406 if (func) 1407 return func; 1408 OpBuilder moduleBuilder(module.getBodyRegion()); 1409 return moduleBuilder.create<LLVM::LLVMFuncOp>( 1410 op->getLoc(), name, 1411 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()), 1412 params)); 1413 } 1414 1415 // Helpers for method names. 1416 Operation *getPrintI64(Operation *op) const { 1417 return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64)); 1418 } 1419 Operation *getPrintU64(Operation *op) const { 1420 return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64)); 1421 } 1422 Operation *getPrintFloat(Operation *op) const { 1423 return getPrint(op, "printF32", Float32Type::get(op->getContext())); 1424 } 1425 Operation *getPrintDouble(Operation *op) const { 1426 return getPrint(op, "printF64", Float64Type::get(op->getContext())); 1427 } 1428 Operation *getPrintOpen(Operation *op) const { 1429 return getPrint(op, "printOpen", {}); 1430 } 1431 Operation *getPrintClose(Operation *op) const { 1432 return getPrint(op, "printClose", {}); 1433 } 1434 Operation *getPrintComma(Operation *op) const { 1435 return getPrint(op, "printComma", {}); 1436 } 1437 Operation *getPrintNewline(Operation *op) const { 1438 return getPrint(op, "printNewline", {}); 1439 } 1440 }; 1441 1442 /// Progressive lowering of ExtractStridedSliceOp to either: 1443 /// 1. express single offset extract as a direct shuffle. 1444 /// 2. extract + lower rank strided_slice + insert for the n-D case. 1445 class VectorExtractStridedSliceOpConversion 1446 : public OpRewritePattern<ExtractStridedSliceOp> { 1447 public: 1448 VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1449 : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1450 // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1451 // is bounded as the rank is strictly decreasing. 1452 setHasBoundedRewriteRecursion(); 1453 } 1454 1455 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 1456 PatternRewriter &rewriter) const override { 1457 auto dstType = op.getType(); 1458 1459 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 1460 1461 int64_t offset = 1462 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 1463 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 1464 int64_t stride = 1465 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 1466 1467 auto loc = op.getLoc(); 1468 auto elemType = dstType.getElementType(); 1469 assert(elemType.isSignlessIntOrIndexOrFloat()); 1470 1471 // Single offset can be more efficiently shuffled. 1472 if (op.offsets().getValue().size() == 1) { 1473 SmallVector<int64_t, 4> offsets; 1474 offsets.reserve(size); 1475 for (int64_t off = offset, e = offset + size * stride; off < e; 1476 off += stride) 1477 offsets.push_back(off); 1478 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1479 op.vector(), 1480 rewriter.getI64ArrayAttr(offsets)); 1481 return success(); 1482 } 1483 1484 // Extract/insert on a lower ranked extract strided slice op. 1485 Value zero = rewriter.create<ConstantOp>(loc, elemType, 1486 rewriter.getZeroAttr(elemType)); 1487 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 1488 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 1489 off += stride, ++idx) { 1490 Value one = extractOne(rewriter, loc, op.vector(), off); 1491 Value extracted = rewriter.create<ExtractStridedSliceOp>( 1492 loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 1493 getI64SubArray(op.sizes(), /* dropFront=*/1), 1494 getI64SubArray(op.strides(), /* dropFront=*/1)); 1495 res = insertOne(rewriter, loc, extracted, res, idx); 1496 } 1497 rewriter.replaceOp(op, res); 1498 return success(); 1499 } 1500 }; 1501 1502 } // namespace 1503 1504 /// Populate the given list with patterns that convert from Vector to LLVM. 1505 void mlir::populateVectorToLLVMConversionPatterns( 1506 LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1507 bool reassociateFPReductions, bool enableIndexOptimizations) { 1508 MLIRContext *ctx = converter.getDialect()->getContext(); 1509 // clang-format off 1510 patterns.insert<VectorFMAOpNDRewritePattern, 1511 VectorInsertStridedSliceOpDifferentRankRewritePattern, 1512 VectorInsertStridedSliceOpSameRankRewritePattern, 1513 VectorExtractStridedSliceOpConversion>(ctx); 1514 patterns.insert<VectorReductionOpConversion>( 1515 converter, reassociateFPReductions); 1516 patterns.insert<VectorCreateMaskOpConversion, 1517 VectorTransferConversion<TransferReadOp>, 1518 VectorTransferConversion<TransferWriteOp>>( 1519 converter, enableIndexOptimizations); 1520 patterns 1521 .insert<VectorShuffleOpConversion, 1522 VectorExtractElementOpConversion, 1523 VectorExtractOpConversion, 1524 VectorFMAOp1DConversion, 1525 VectorInsertElementOpConversion, 1526 VectorInsertOpConversion, 1527 VectorPrintOpConversion, 1528 VectorTypeCastOpConversion, 1529 VectorMaskedLoadOpConversion, 1530 VectorMaskedStoreOpConversion, 1531 VectorGatherOpConversion, 1532 VectorScatterOpConversion, 1533 VectorExpandLoadOpConversion, 1534 VectorCompressStoreOpConversion>(converter); 1535 // clang-format on 1536 } 1537 1538 void mlir::populateVectorToLLVMMatrixConversionPatterns( 1539 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1540 patterns.insert<VectorMatmulOpConversion>(converter); 1541 patterns.insert<VectorFlatTransposeOpConversion>(converter); 1542 } 1543