//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
                                   Location loc, Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
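// For illustration, with `dim` = 4, the 32-bit index optimization enabled, and
// an offset `%off`, the emitted IR is roughly (a sketch; value names are
// illustrative and the exact std-op syntax depends on the MLIR version):
//   %cst    = constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %ov     = splat %off : vector<4xi32>
//   %idx    = addi %ov, %cst : vector<4xi32>
//   %bounds = splat %b : vector<4xi32>
//   %mask   = cmpi slt, %idx, %bounds : vector<4xi32>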
// Helper that returns the data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a vector of pointers, given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer is
// always in address space 0, so an addrspacecast is needed when the source or
// destination memref is in a different address space.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt);
  if (memRefType.getMemorySpace() == 0)
    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.bitcast.
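/// For illustration, a 1-D bitcast that preserves the total bit width, e.g.:
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xf64>
/// ```
/// lowers to a single `llvm.bitcast` on the converted vector type.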
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};
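// For illustration, a sketch of the two matrix intrinsics above (attribute
// syntax abbreviated; the shapes are examples only): a 2x3 * 3x2 multiply on
// row-major flattened 1-D vectors such as
//   %c = vector.matrix_multiply %a, %b
//          {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 2 : i32}
//        : (vector<6xf32>, vector<6xf32>) -> vector<4xf32>
// maps 1-1 onto llvm.intr.matrix.multiply with the same attributes, and
// vector.flat_transpose maps onto llvm.intr.matrix.transpose likewise.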
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
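/// For illustration (a sketch; the alignment depends on the data layout):
/// ```
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an `llvm.load` of a `vector<8xf32>` through a pointer obtained by
/// GEP + bitcast from the memref descriptor, annotated with the element
/// type's preferred alignment.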
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion
    : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    auto adaptor = LoadOrStoreOpAdaptor(operands);
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
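/// For illustration (a sketch; the exact assembly format is elided): a gather
/// of 4 f32 values from a flat memref through an index vector %idxs, with
/// mask %mask and pass-through %pass, becomes an `llvm.intr.masked.gather`
/// over a vector of pointers obtained by a single GEP of %idxs off the
/// aligned base pointer.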
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getResultVectorType());
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};
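// For illustration (a sketch), an integer reduction such as
//   %r = vector.reduction "add", %v : vector<16xi32> into i32
// maps 1-1 onto llvm.intr.vector.reduce.add, while a floating-point "add"
// with an accumulator maps onto llvm.intr.vector.reduce.fadd, whose
// reassociation is controlled by the `reassociateFPReductions` flag.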
/// Conversion pattern for a vector.create_mask (1-D only).
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract and insert the individual values one by
    // one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va:
///     (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///     -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0 : vector<2x4xf32>
///   %va = vector.extract %a[0] : vector<2x4xf32>
///   %vb = vector.extract %b[0] : vector<2x4xf32>
///   %vc = vector.extract %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extract %a[1] : vector<2x4xf32>
///   %vb2 = vector.extract %b[1] : vector<2x4xf32>
///   %vc2 = vector.extract %c[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. When the ranks differ, InsertStridedSlice needs to
// extract a properly ranked vector from the destination vector into which to
// insert; this pattern only takes care of that part and forwards the rest of
// the conversion to a pattern operating on matching ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. has matching ranks and will be
// picked up by a `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract the subvector of matching rank from the destination and
    // InsertStridedSlice into it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, we unroll along the leading dimension:
// for each position in the slice,
//   1. the proper subvector (or element) is extracted from the source
//   2. for subvectors, the corresponding subvector is extracted from the
//      destination and a new, rank-reduced InsertStridedSlice op inserts the
//      source subvector into it
//   3. the result is inserted back into the proper place of the destination.
// The rank-reduced InsertStridedSlice ops from step 2. are picked up again by
// this pattern; the recursion is bounded since the rank strictly decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the outermost (most major)
    // dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}
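// For example, a memref<4x8xf32> with identity layout has strides [8, 1] and
// is contiguous, whereas a strided layout with strides [16, 1] over the same
// 4x8 shape (padded rows) is rejected because 16 != 1 * 8.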
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
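/// For illustration (a sketch; the permutation map defaults to the minor
/// identity), a masked 1-D read
/// ```
///   %f = vector.transfer_read %A[%i], %f0 : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an `llvm.intr.masked.load` through a bitcast vector pointer at
/// offset %i, with mask [%i .. %i+7] < dim %A and a splat of %f0 as the
/// pass-through value.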
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source tensors supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim be the memref dimension; compute the vector comparison mask:
    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    // dimensions here.
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
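  // For illustration (a sketch; the runtime symbols come from
  // FunctionCallUtils), printing a vector<2xi32> value [1, 2] unrolls into
  // calls resembling:
  //   printOpen(); printI64(1); printComma(); printI64(2); printClose();
  //   printNewline();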
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }
  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
/// 1. express single offset extract as a direct shuffle.
/// 2. extract + lower rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorBitCastOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorLoadStoreConversion<vector::LoadOp,
                                        vector::LoadOpAdaptor>,
              VectorLoadStoreConversion<vector::MaskedLoadOp,
                                        vector::MaskedLoadOpAdaptor>,
              VectorLoadStoreConversion<vector::StoreOp,
                                        vector::StoreOpAdaptor>,
              VectorLoadStoreConversion<vector::MaskedStoreOp,
                                        vector::MaskedStoreOpAdaptor>,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}
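// For illustration (a sketch, assuming a standard LLVM lowering pipeline),
// these entry points are typically driven as follows; `module` and `target`
// are assumed to be set up by the caller:
//
//   LLVMTypeConverter converter(module.getContext());
//   OwningRewritePatternList patterns;
//   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
//   populateVectorToLLVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();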