//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
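// For rank > 1 this emits a vector.extract of the subvector at `offset`; at
// rank 1 it emits a vector.extractelement of a scalar.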
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns a vector comparison that constructs a mask:
//   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
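// Fails for memrefs that are not flat, unit-stride buffers at offset 0 in the
// default memory space.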
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = LLVM::LLVMPointerType::get(type.cast<LLVM::LLVMType>());
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a vector of pointers given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

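  // Replace the transfer read with a masked load; disabled lanes take their
  // values from the splatted padding vector.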
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
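/// For illustration, a rough sketch of the rewrite (operand syntax
/// abbreviated; not verbatim compiler output):
/// ```
///   %0 = vector.maskedload %base, %mask, %pass_thru
///       : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.load on the memref's aligned base pointer,
/// bitcast to a pointer to the vector type, with the alignment taken from
/// the data layout.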
class VectorMaskedLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = load->getLoc();
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
                                  align)))
      return failure();

    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = store->getLoc();
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
                                  align)))
      return failure();

    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
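/// Like the gather above, this computes a vector of addresses with a single
/// GEP over the index vector off the memref base, and feeds it to the
/// llvm.intr.masked.scatter intrinsic, so the same flat-buffer restriction
/// from getBase applies.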
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
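/// Integer kinds map 1-1 onto llvm.intr.vector.reduce.* (with the unsigned
/// min/max variants chosen for index and unsigned element types); fp "add"
/// and "mul" thread an optional accumulator operand through the intrinsic.
/// For illustration, a sketch (not verbatim compiler output):
/// ```
///   %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// becomes llvm.intr.vector.reduce.fadd with a 0.0 accumulator.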
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
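/// For illustration, a sketch of the generated compare (not verbatim output):
/// ```
///   %0 = vector.create_mask %b : vector<4xi1>
/// ```
/// becomes the comparison [0, 1, 2, 3] < [%b, %b, %b, %b] built by
/// buildVectorComparison above, computed in i32 when index optimizations
/// are enabled.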
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op->getResult(0).getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the values one by one: each mask entry
    // selects a source vector (v1 or v2) and a position within it.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
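/// Higher-rank FMAs are first rank-reduced to the 1-D form by the
/// VectorFMAOpNDRewritePattern below, so that this pattern eventually applies.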
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va:
///     (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///     -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0: vector<2x4xf32>
///   %va = vector.extract %a[0] : vector<2x4xf32>
///   %vb = vector.extract %b[0] : vector<2x4xf32>
///   %vc = vector.extract %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extract %a[1] : vector<2x4xf32>
///   %vb2 = vector.extract %b[1] : vector<2x4xf32>
///   %vc2 = vector.extract %c[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// In other words, when source and destination vectors have different ranks:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source into the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2.
// will be picked up by a `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source along the most major
// dimension:
//   1. the corresponding subvector (or element) is extracted from the source
//   2. for subvectors, the matching subvector is extracted from the
//      destination and a new, rank-reduced InsertStridedSlice op inserts the
//      source slice into it
//   3. the result is inserted back into the proper position of the result
//      vector.
// The rank-reduced InsertStridedSlice ops from step 2 are recursively picked
// up by this same pattern; the recursion is bounded since the rank strictly
// decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from the source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        //    destination. Otherwise we are at the element level and there is
        //    no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        //    smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType cannot represent contiguous
  // dynamic shapes other than with an empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
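///
/// For illustration only, a sketch of the masked-read case (names and exact
/// types are not verbatim compiler output):
/// ```
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes, roughly, a bitcast of the element pointer at %A[%i] to a pointer
/// to vector<8xf32>, a mask [%i + 0 .. %i + 7] < dim(%A, 0), and an
/// llvm.intr.masked.load whose pass-through operand is splat(%pad).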
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source memrefs supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector element type is a suffix of
      // 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, so an
    //    addrspacecast is used when the source/dst memrefs are not in
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    auto vecTy = toLLVMTy(xferOp.getVectorType())
                     .template cast<LLVM::LLVMFixedVectorType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim be the memref dimension; compute the vector comparison mask:
    //      [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned vecWidth = vecTy.getNumElements();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(),
                                       loc, xferOp, operands, vectorDataPtr,
                                       mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors, though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer = getPrintFloat(printOp);
    } else if (eltType.isF64()) {
      printer = getPrintDouble(printOp);
    } else if (eltType.isIndex()) {
      printer = getPrintU64(printOp);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = getPrintU64(printOp);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = getPrintI64(printOp);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
                                    params));
  }

  // Helpers that look up (or declare on first use) the runtime support
  // functions by name.
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "printI64",
                    LLVM::LLVMIntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintU64(Operation *op) const {
    return getPrint(op, "printU64",
                    LLVM::LLVMIntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "printF32",
                    LLVM::LLVMFloatType::get(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "printF64",
                    LLVM::LLVMDoubleType::get(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "printOpen", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "printClose", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "printComma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "printNewline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express single offset extract as a direct shuffle.
///   2. extract + lower rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
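    // Each major-dimension slice is extracted, strided-sliced recursively at
    // rank - 1, and re-inserted into a zero-splat result vector.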
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}