1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 19 #include "mlir/Dialect/Utils/StaticValueUtils.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/IR/Attributes.h" 22 #include "mlir/IR/BuiltinAttributes.h" 23 #include "mlir/IR/BuiltinTypes.h" 24 #include "mlir/IR/Location.h" 25 #include "mlir/IR/Matchers.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/IR/TypeUtilities.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "llvm/ADT/ArrayRef.h" 30 #include "llvm/ADT/STLExtras.h" 31 #include "llvm/ADT/SmallVector.h" 32 #include "llvm/ADT/SmallVectorExtras.h" 33 #include "llvm/Support/FormatVariadic.h" 34 #include <cassert> 35 #include <cstdint> 36 #include <numeric> 37 38 using namespace mlir; 39 40 /// Returns the integer value from the first valid input element, assuming Value 41 /// inputs are defined by a constant index ops and Attribute inputs are integer 42 /// attributes. 43 static uint64_t getFirstIntValue(ArrayAttr attr) { 44 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 45 } 46 47 /// Returns the number of bits for the given scalar/vector type. 48 static int getNumBits(Type type) { 49 // TODO: This does not take into account any memory layout or widening 50 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even 51 // though in practice it will likely be stored as in a 4xi64 vector register. 52 if (auto vectorType = dyn_cast<VectorType>(type)) 53 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); 54 return type.getIntOrFloatBitWidth(); 55 } 56 57 namespace { 58 59 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> { 60 using OpConversionPattern::OpConversionPattern; 61 62 LogicalResult 63 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor, 64 ConversionPatternRewriter &rewriter) const override { 65 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType()); 66 if (!dstType) 67 return failure(); 68 69 // If dstType is same as the source type or the vector size is 1, it can be 70 // directly replaced by the source. 71 if (dstType == adaptor.getSource().getType() || 72 shapeCastOp.getResultVectorType().getNumElements() == 1) { 73 rewriter.replaceOp(shapeCastOp, adaptor.getSource()); 74 return success(); 75 } 76 77 // Lowering for size-n vectors when n > 1 hasn't been implemented. 78 return failure(); 79 } 80 }; 81 82 struct VectorBitcastConvert final 83 : public OpConversionPattern<vector::BitCastOp> { 84 using OpConversionPattern::OpConversionPattern; 85 86 LogicalResult 87 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override { 89 Type dstType = getTypeConverter()->convertType(bitcastOp.getType()); 90 if (!dstType) 91 return failure(); 92 93 if (dstType == adaptor.getSource().getType()) { 94 rewriter.replaceOp(bitcastOp, adaptor.getSource()); 95 return success(); 96 } 97 98 // Check that the source and destination type have the same bitwidth. 99 // Depending on the target environment, we may need to emulate certain 100 // types, which can cause issue with bitcast. 101 Type srcType = adaptor.getSource().getType(); 102 if (getNumBits(dstType) != getNumBits(srcType)) { 103 return rewriter.notifyMatchFailure( 104 bitcastOp, 105 llvm::formatv("different source ({0}) and target ({1}) bitwidth", 106 srcType, dstType)); 107 } 108 109 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 110 adaptor.getSource()); 111 return success(); 112 } 113 }; 114 115 struct VectorBroadcastConvert final 116 : public OpConversionPattern<vector::BroadcastOp> { 117 using OpConversionPattern::OpConversionPattern; 118 119 LogicalResult 120 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor, 121 ConversionPatternRewriter &rewriter) const override { 122 Type resultType = 123 getTypeConverter()->convertType(castOp.getResultVectorType()); 124 if (!resultType) 125 return failure(); 126 127 if (isa<spirv::ScalarType>(resultType)) { 128 rewriter.replaceOp(castOp, adaptor.getSource()); 129 return success(); 130 } 131 132 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(), 133 adaptor.getSource()); 134 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType, 135 source); 136 return success(); 137 } 138 }; 139 140 struct VectorExtractOpConvert final 141 : public OpConversionPattern<vector::ExtractOp> { 142 using OpConversionPattern::OpConversionPattern; 143 144 LogicalResult 145 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 146 ConversionPatternRewriter &rewriter) const override { 147 Type dstType = getTypeConverter()->convertType(extractOp.getType()); 148 if (!dstType) 149 return failure(); 150 151 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { 152 rewriter.replaceOp(extractOp, adaptor.getVector()); 153 return success(); 154 } 155 156 if (std::optional<int64_t> id = 157 getConstantIntValue(extractOp.getMixedPosition()[0])) 158 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 159 extractOp, dstType, adaptor.getVector(), 160 rewriter.getI32ArrayAttr(id.value())); 161 else 162 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 163 extractOp, dstType, adaptor.getVector(), 164 adaptor.getDynamicPosition()[0]); 165 return success(); 166 } 167 }; 168 169 struct VectorExtractStridedSliceOpConvert final 170 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 171 using OpConversionPattern::OpConversionPattern; 172 173 LogicalResult 174 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 175 ConversionPatternRewriter &rewriter) const override { 176 Type dstType = getTypeConverter()->convertType(extractOp.getType()); 177 if (!dstType) 178 return failure(); 179 180 uint64_t offset = getFirstIntValue(extractOp.getOffsets()); 181 uint64_t size = getFirstIntValue(extractOp.getSizes()); 182 uint64_t stride = getFirstIntValue(extractOp.getStrides()); 183 if (stride != 1) 184 return failure(); 185 186 Value srcVector = adaptor.getOperands().front(); 187 188 // Extract vector<1xT> case. 189 if (isa<spirv::ScalarType>(dstType)) { 190 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 191 srcVector, offset); 192 return success(); 193 } 194 195 SmallVector<int32_t, 2> indices(size); 196 std::iota(indices.begin(), indices.end(), offset); 197 198 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 199 extractOp, dstType, srcVector, srcVector, 200 rewriter.getI32ArrayAttr(indices)); 201 202 return success(); 203 } 204 }; 205 206 template <class SPIRVFMAOp> 207 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 208 using OpConversionPattern::OpConversionPattern; 209 210 LogicalResult 211 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 212 ConversionPatternRewriter &rewriter) const override { 213 Type dstType = getTypeConverter()->convertType(fmaOp.getType()); 214 if (!dstType) 215 return failure(); 216 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(), 217 adaptor.getRhs(), adaptor.getAcc()); 218 return success(); 219 } 220 }; 221 222 struct VectorFromElementsOpConvert final 223 : public OpConversionPattern<vector::FromElementsOp> { 224 using OpConversionPattern::OpConversionPattern; 225 226 LogicalResult 227 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor, 228 ConversionPatternRewriter &rewriter) const override { 229 Type resultType = getTypeConverter()->convertType(op.getType()); 230 if (!resultType) 231 return failure(); 232 OperandRange elements = op.getElements(); 233 if (isa<spirv::ScalarType>(resultType)) { 234 // In the case with a single scalar operand / single-element result, 235 // pass through the scalar. 236 rewriter.replaceOp(op, elements[0]); 237 return success(); 238 } 239 // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional 240 // vector.from_elements cases should not need to be handled, only 1d. 241 assert(cast<VectorType>(resultType).getRank() == 1); 242 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType, 243 elements); 244 return success(); 245 } 246 }; 247 248 struct VectorInsertOpConvert final 249 : public OpConversionPattern<vector::InsertOp> { 250 using OpConversionPattern::OpConversionPattern; 251 252 LogicalResult 253 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 254 ConversionPatternRewriter &rewriter) const override { 255 if (isa<VectorType>(insertOp.getSourceType())) 256 return rewriter.notifyMatchFailure(insertOp, "unsupported vector source"); 257 if (!getTypeConverter()->convertType(insertOp.getDestVectorType())) 258 return rewriter.notifyMatchFailure(insertOp, 259 "unsupported dest vector type"); 260 261 // Special case for inserting scalar values into size-1 vectors. 262 if (insertOp.getSourceType().isIntOrFloat() && 263 insertOp.getDestVectorType().getNumElements() == 1) { 264 rewriter.replaceOp(insertOp, adaptor.getSource()); 265 return success(); 266 } 267 268 if (std::optional<int64_t> id = 269 getConstantIntValue(insertOp.getMixedPosition()[0])) 270 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 271 insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); 272 else 273 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 274 insertOp, insertOp.getDest(), adaptor.getSource(), 275 adaptor.getDynamicPosition()[0]); 276 return success(); 277 } 278 }; 279 280 struct VectorExtractElementOpConvert final 281 : public OpConversionPattern<vector::ExtractElementOp> { 282 using OpConversionPattern::OpConversionPattern; 283 284 LogicalResult 285 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, 286 ConversionPatternRewriter &rewriter) const override { 287 Type resultType = getTypeConverter()->convertType(extractOp.getType()); 288 if (!resultType) 289 return failure(); 290 291 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { 292 rewriter.replaceOp(extractOp, adaptor.getVector()); 293 return success(); 294 } 295 296 APInt cstPos; 297 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) 298 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 299 extractOp, resultType, adaptor.getVector(), 300 rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); 301 else 302 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 303 extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); 304 return success(); 305 } 306 }; 307 308 struct VectorInsertElementOpConvert final 309 : public OpConversionPattern<vector::InsertElementOp> { 310 using OpConversionPattern::OpConversionPattern; 311 312 LogicalResult 313 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, 314 ConversionPatternRewriter &rewriter) const override { 315 Type vectorType = getTypeConverter()->convertType(insertOp.getType()); 316 if (!vectorType) 317 return failure(); 318 319 if (isa<spirv::ScalarType>(vectorType)) { 320 rewriter.replaceOp(insertOp, adaptor.getSource()); 321 return success(); 322 } 323 324 APInt cstPos; 325 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) 326 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 327 insertOp, adaptor.getSource(), adaptor.getDest(), 328 cstPos.getSExtValue()); 329 else 330 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 331 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), 332 adaptor.getPosition()); 333 return success(); 334 } 335 }; 336 337 struct VectorInsertStridedSliceOpConvert final 338 : public OpConversionPattern<vector::InsertStridedSliceOp> { 339 using OpConversionPattern::OpConversionPattern; 340 341 LogicalResult 342 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, 343 ConversionPatternRewriter &rewriter) const override { 344 Value srcVector = adaptor.getOperands().front(); 345 Value dstVector = adaptor.getOperands().back(); 346 347 uint64_t stride = getFirstIntValue(insertOp.getStrides()); 348 if (stride != 1) 349 return failure(); 350 uint64_t offset = getFirstIntValue(insertOp.getOffsets()); 351 352 if (isa<spirv::ScalarType>(srcVector.getType())) { 353 assert(!isa<spirv::ScalarType>(dstVector.getType())); 354 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 355 insertOp, dstVector.getType(), srcVector, dstVector, 356 rewriter.getI32ArrayAttr(offset)); 357 return success(); 358 } 359 360 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements(); 361 uint64_t insertSize = 362 cast<VectorType>(srcVector.getType()).getNumElements(); 363 364 SmallVector<int32_t, 2> indices(totalSize); 365 std::iota(indices.begin(), indices.end(), 0); 366 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 367 totalSize); 368 369 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 370 insertOp, dstVector.getType(), dstVector, srcVector, 371 rewriter.getI32ArrayAttr(indices)); 372 373 return success(); 374 } 375 }; 376 377 static SmallVector<Value> extractAllElements( 378 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor, 379 VectorType srcVectorType, ConversionPatternRewriter &rewriter) { 380 int numElements = static_cast<int>(srcVectorType.getDimSize(0)); 381 SmallVector<Value> values; 382 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0)); 383 Location loc = reduceOp.getLoc(); 384 385 for (int i = 0; i < numElements; ++i) { 386 values.push_back(rewriter.create<spirv::CompositeExtractOp>( 387 loc, srcVectorType.getElementType(), adaptor.getVector(), 388 rewriter.getI32ArrayAttr({i}))); 389 } 390 if (Value acc = adaptor.getAcc()) 391 values.push_back(acc); 392 393 return values; 394 } 395 396 struct ReductionRewriteInfo { 397 Type resultType; 398 SmallVector<Value> extractedElements; 399 }; 400 401 FailureOr<ReductionRewriteInfo> static getReductionInfo( 402 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor, 403 ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) { 404 Type resultType = typeConverter.convertType(op.getType()); 405 if (!resultType) 406 return failure(); 407 408 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType()); 409 if (!srcVectorType || srcVectorType.getRank() != 1) 410 return rewriter.notifyMatchFailure(op, "not a 1-D vector source"); 411 412 SmallVector<Value> extractedElements = 413 extractAllElements(op, adaptor, srcVectorType, rewriter); 414 415 return ReductionRewriteInfo{resultType, std::move(extractedElements)}; 416 } 417 418 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp, 419 typename SPIRVSMinOp> 420 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> { 421 using OpConversionPattern::OpConversionPattern; 422 423 LogicalResult 424 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, 425 ConversionPatternRewriter &rewriter) const override { 426 auto reductionInfo = 427 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter()); 428 if (failed(reductionInfo)) 429 return failure(); 430 431 auto [resultType, extractedElements] = *reductionInfo; 432 Location loc = reduceOp->getLoc(); 433 Value result = extractedElements.front(); 434 for (Value next : llvm::drop_begin(extractedElements)) { 435 switch (reduceOp.getKind()) { 436 437 #define INT_AND_FLOAT_CASE(kind, iop, fop) \ 438 case vector::CombiningKind::kind: \ 439 if (llvm::isa<IntegerType>(resultType)) { \ 440 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \ 441 } else { \ 442 assert(llvm::isa<FloatType>(resultType)); \ 443 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \ 444 } \ 445 break 446 447 #define INT_OR_FLOAT_CASE(kind, fop) \ 448 case vector::CombiningKind::kind: \ 449 result = rewriter.create<fop>(loc, resultType, result, next); \ 450 break 451 452 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); 453 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); 454 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); 455 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); 456 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); 457 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp); 458 459 case vector::CombiningKind::AND: 460 case vector::CombiningKind::OR: 461 case vector::CombiningKind::XOR: 462 return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); 463 default: 464 return rewriter.notifyMatchFailure(reduceOp, "not handled here"); 465 } 466 #undef INT_AND_FLOAT_CASE 467 #undef INT_OR_FLOAT_CASE 468 } 469 470 rewriter.replaceOp(reduceOp, result); 471 return success(); 472 } 473 }; 474 475 template <typename SPIRVFMaxOp, typename SPIRVFMinOp> 476 struct VectorReductionFloatMinMax final 477 : OpConversionPattern<vector::ReductionOp> { 478 using OpConversionPattern::OpConversionPattern; 479 480 LogicalResult 481 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, 482 ConversionPatternRewriter &rewriter) const override { 483 auto reductionInfo = 484 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter()); 485 if (failed(reductionInfo)) 486 return failure(); 487 488 auto [resultType, extractedElements] = *reductionInfo; 489 Location loc = reduceOp->getLoc(); 490 Value result = extractedElements.front(); 491 for (Value next : llvm::drop_begin(extractedElements)) { 492 switch (reduceOp.getKind()) { 493 494 #define INT_OR_FLOAT_CASE(kind, fop) \ 495 case vector::CombiningKind::kind: \ 496 result = rewriter.create<fop>(loc, resultType, result, next); \ 497 break 498 499 INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); 500 INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); 501 INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp); 502 INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp); 503 504 default: 505 return rewriter.notifyMatchFailure(reduceOp, "not handled here"); 506 } 507 #undef INT_OR_FLOAT_CASE 508 } 509 510 rewriter.replaceOp(reduceOp, result); 511 return success(); 512 } 513 }; 514 515 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { 516 public: 517 using OpConversionPattern<vector::SplatOp>::OpConversionPattern; 518 519 LogicalResult 520 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, 521 ConversionPatternRewriter &rewriter) const override { 522 Type dstType = getTypeConverter()->convertType(op.getType()); 523 if (!dstType) 524 return failure(); 525 if (isa<spirv::ScalarType>(dstType)) { 526 rewriter.replaceOp(op, adaptor.getInput()); 527 } else { 528 auto dstVecType = cast<VectorType>(dstType); 529 SmallVector<Value, 4> source(dstVecType.getNumElements(), 530 adaptor.getInput()); 531 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType, 532 source); 533 } 534 return success(); 535 } 536 }; 537 538 struct VectorShuffleOpConvert final 539 : public OpConversionPattern<vector::ShuffleOp> { 540 using OpConversionPattern::OpConversionPattern; 541 542 LogicalResult 543 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 544 ConversionPatternRewriter &rewriter) const override { 545 VectorType oldResultType = shuffleOp.getResultVectorType(); 546 Type newResultType = getTypeConverter()->convertType(oldResultType); 547 if (!newResultType) 548 return rewriter.notifyMatchFailure(shuffleOp, 549 "unsupported result vector type"); 550 551 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask()); 552 553 VectorType oldV1Type = shuffleOp.getV1VectorType(); 554 VectorType oldV2Type = shuffleOp.getV2VectorType(); 555 556 // When both operands and the result are SPIR-V vectors, emit a SPIR-V 557 // shuffle. 558 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 && 559 oldResultType.getNumElements() > 1) { 560 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 561 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), 562 rewriter.getI32ArrayAttr(mask)); 563 return success(); 564 } 565 566 // When at least one of the operands or the result becomes a scalar after 567 // type conversion for SPIR-V, extract all the required elements and 568 // construct the result vector. 569 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( 570 Value scalarOrVec, int32_t idx) -> Value { 571 if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType())) 572 return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec, 573 idx); 574 575 assert(idx == 0 && "Invalid scalar element index"); 576 return scalarOrVec; 577 }; 578 579 int32_t numV1Elems = oldV1Type.getNumElements(); 580 SmallVector<Value> newOperands(mask.size()); 581 for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) { 582 Value vec = adaptor.getV1(); 583 int32_t elementIdx = shuffleIdx; 584 if (elementIdx >= numV1Elems) { 585 vec = adaptor.getV2(); 586 elementIdx -= numV1Elems; 587 } 588 589 newOperand = getElementAtIdx(vec, elementIdx); 590 } 591 592 // Handle the scalar result corner case. 593 if (newOperands.size() == 1) { 594 rewriter.replaceOp(shuffleOp, newOperands.front()); 595 return success(); 596 } 597 598 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 599 shuffleOp, newResultType, newOperands); 600 return success(); 601 } 602 }; 603 604 struct VectorInterleaveOpConvert final 605 : public OpConversionPattern<vector::InterleaveOp> { 606 using OpConversionPattern::OpConversionPattern; 607 608 LogicalResult 609 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, 610 ConversionPatternRewriter &rewriter) const override { 611 // Check the result vector type. 612 VectorType oldResultType = interleaveOp.getResultVectorType(); 613 Type newResultType = getTypeConverter()->convertType(oldResultType); 614 if (!newResultType) 615 return rewriter.notifyMatchFailure(interleaveOp, 616 "unsupported result vector type"); 617 618 // Interleave the indices. 619 VectorType sourceType = interleaveOp.getSourceVectorType(); 620 int n = sourceType.getNumElements(); 621 622 // Input vectors of size 1 are converted to scalars by the type converter. 623 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to 624 // use `spirv::CompositeConstructOp`. 625 if (n == 1) { 626 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()}; 627 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 628 interleaveOp, newResultType, newOperands); 629 return success(); 630 } 631 632 auto seq = llvm::seq<int64_t>(2 * n); 633 auto indices = llvm::map_to_vector( 634 seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }); 635 636 // Emit a SPIR-V shuffle. 637 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 638 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(), 639 rewriter.getI32ArrayAttr(indices)); 640 641 return success(); 642 } 643 }; 644 645 struct VectorDeinterleaveOpConvert final 646 : public OpConversionPattern<vector::DeinterleaveOp> { 647 using OpConversionPattern::OpConversionPattern; 648 649 LogicalResult 650 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor, 651 ConversionPatternRewriter &rewriter) const override { 652 653 // Check the result vector type. 654 VectorType oldResultType = deinterleaveOp.getResultVectorType(); 655 Type newResultType = getTypeConverter()->convertType(oldResultType); 656 if (!newResultType) 657 return rewriter.notifyMatchFailure(deinterleaveOp, 658 "unsupported result vector type"); 659 660 Location loc = deinterleaveOp->getLoc(); 661 662 // Deinterleave the indices. 663 Value sourceVector = adaptor.getSource(); 664 VectorType sourceType = deinterleaveOp.getSourceVectorType(); 665 int n = sourceType.getNumElements(); 666 667 // Output vectors of size 1 are converted to scalars by the type converter. 668 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to 669 // use `spirv::CompositeExtractOp`. 670 if (n == 2) { 671 auto elem0 = rewriter.create<spirv::CompositeExtractOp>( 672 loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0})); 673 674 auto elem1 = rewriter.create<spirv::CompositeExtractOp>( 675 loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1})); 676 677 rewriter.replaceOp(deinterleaveOp, {elem0, elem1}); 678 return success(); 679 } 680 681 // Indices for `shuffleEven` (result 0). 682 auto seqEven = llvm::seq<int64_t>(n / 2); 683 auto indicesEven = 684 llvm::map_to_vector(seqEven, [](int i) { return i * 2; }); 685 686 // Indices for `shuffleOdd` (result 1). 687 auto seqOdd = llvm::seq<int64_t>(n / 2); 688 auto indicesOdd = 689 llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; }); 690 691 // Create two SPIR-V shuffles. 692 auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>( 693 loc, newResultType, sourceVector, sourceVector, 694 rewriter.getI32ArrayAttr(indicesEven)); 695 696 auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>( 697 loc, newResultType, sourceVector, sourceVector, 698 rewriter.getI32ArrayAttr(indicesOdd)); 699 700 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd}); 701 return success(); 702 } 703 }; 704 705 struct VectorLoadOpConverter final 706 : public OpConversionPattern<vector::LoadOp> { 707 using OpConversionPattern::OpConversionPattern; 708 709 LogicalResult 710 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, 711 ConversionPatternRewriter &rewriter) const override { 712 auto memrefType = loadOp.getMemRefType(); 713 auto attr = 714 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace()); 715 if (!attr) 716 return rewriter.notifyMatchFailure( 717 loadOp, "expected spirv.storage_class memory space"); 718 719 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 720 auto loc = loadOp.getLoc(); 721 Value accessChain = 722 spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(), 723 adaptor.getIndices(), loc, rewriter); 724 if (!accessChain) 725 return rewriter.notifyMatchFailure( 726 loadOp, "failed to get memref element pointer"); 727 728 spirv::StorageClass storageClass = attr.getValue(); 729 auto vectorType = loadOp.getVectorType(); 730 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); 731 Value castedAccessChain = 732 rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain); 733 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType, 734 castedAccessChain); 735 736 return success(); 737 } 738 }; 739 740 struct VectorStoreOpConverter final 741 : public OpConversionPattern<vector::StoreOp> { 742 using OpConversionPattern::OpConversionPattern; 743 744 LogicalResult 745 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, 746 ConversionPatternRewriter &rewriter) const override { 747 auto memrefType = storeOp.getMemRefType(); 748 auto attr = 749 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace()); 750 if (!attr) 751 return rewriter.notifyMatchFailure( 752 storeOp, "expected spirv.storage_class memory space"); 753 754 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 755 auto loc = storeOp.getLoc(); 756 Value accessChain = 757 spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(), 758 adaptor.getIndices(), loc, rewriter); 759 if (!accessChain) 760 return rewriter.notifyMatchFailure( 761 storeOp, "failed to get memref element pointer"); 762 763 spirv::StorageClass storageClass = attr.getValue(); 764 auto vectorType = storeOp.getVectorType(); 765 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); 766 Value castedAccessChain = 767 rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain); 768 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain, 769 adaptor.getValueToStore()); 770 771 return success(); 772 } 773 }; 774 775 struct VectorReductionToIntDotProd final 776 : OpRewritePattern<vector::ReductionOp> { 777 using OpRewritePattern::OpRewritePattern; 778 779 LogicalResult matchAndRewrite(vector::ReductionOp op, 780 PatternRewriter &rewriter) const override { 781 if (op.getKind() != vector::CombiningKind::ADD) 782 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'"); 783 784 auto resultType = dyn_cast<IntegerType>(op.getType()); 785 if (!resultType) 786 return rewriter.notifyMatchFailure(op, "result is not an integer"); 787 788 int64_t resultBitwidth = resultType.getIntOrFloatBitWidth(); 789 if (!llvm::is_contained({32, 64}, resultBitwidth)) 790 return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth"); 791 792 VectorType inVecTy = op.getSourceVectorType(); 793 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) || 794 inVecTy.getShape().size() != 1 || inVecTy.isScalable()) 795 return rewriter.notifyMatchFailure(op, "unsupported vector shape"); 796 797 auto mul = op.getVector().getDefiningOp<arith::MulIOp>(); 798 if (!mul) 799 return rewriter.notifyMatchFailure( 800 op, "reduction operand is not 'arith.muli'"); 801 802 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp, 803 spirv::SDotAccSatOp, false>(op, mul, rewriter))) 804 return success(); 805 806 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp, 807 spirv::UDotAccSatOp, false>(op, mul, rewriter))) 808 return success(); 809 810 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp, 811 spirv::SUDotAccSatOp, false>(op, mul, rewriter))) 812 return success(); 813 814 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp, 815 spirv::SUDotAccSatOp, true>(op, mul, rewriter))) 816 return success(); 817 818 return failure(); 819 } 820 821 private: 822 template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp, 823 typename DotAccOp, bool SwapOperands> 824 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul, 825 PatternRewriter &rewriter) { 826 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>(); 827 if (!lhs) 828 return failure(); 829 Value lhsIn = lhs.getIn(); 830 auto lhsInType = cast<VectorType>(lhsIn.getType()); 831 if (!lhsInType.getElementType().isInteger(8)) 832 return failure(); 833 834 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>(); 835 if (!rhs) 836 return failure(); 837 Value rhsIn = rhs.getIn(); 838 auto rhsInType = cast<VectorType>(rhsIn.getType()); 839 if (!rhsInType.getElementType().isInteger(8)) 840 return failure(); 841 842 if (op.getSourceVectorType().getNumElements() == 3) { 843 IntegerType i8Type = rewriter.getI8Type(); 844 auto v4i8Type = VectorType::get({4}, i8Type); 845 Location loc = op.getLoc(); 846 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); 847 lhsIn = rewriter.create<spirv::CompositeConstructOp>( 848 loc, v4i8Type, ValueRange{lhsIn, zero}); 849 rhsIn = rewriter.create<spirv::CompositeConstructOp>( 850 loc, v4i8Type, ValueRange{rhsIn, zero}); 851 } 852 853 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so 854 // we have to swap operands instead in that case. 855 if (SwapOperands) 856 std::swap(lhsIn, rhsIn); 857 858 if (Value acc = op.getAcc()) { 859 rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc, 860 nullptr); 861 } else { 862 rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn, 863 nullptr); 864 } 865 866 return success(); 867 } 868 }; 869 870 struct VectorReductionToFPDotProd final 871 : OpConversionPattern<vector::ReductionOp> { 872 using OpConversionPattern::OpConversionPattern; 873 874 LogicalResult 875 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor, 876 ConversionPatternRewriter &rewriter) const override { 877 if (op.getKind() != vector::CombiningKind::ADD) 878 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'"); 879 880 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType()); 881 if (!resultType) 882 return rewriter.notifyMatchFailure(op, "result is not a float"); 883 884 Value vec = adaptor.getVector(); 885 Value acc = adaptor.getAcc(); 886 887 auto vectorType = dyn_cast<VectorType>(vec.getType()); 888 if (!vectorType) { 889 assert(isa<FloatType>(vec.getType()) && 890 "Expected the vector to be scalarized"); 891 if (acc) { 892 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec); 893 return success(); 894 } 895 896 rewriter.replaceOp(op, vec); 897 return success(); 898 } 899 900 Location loc = op.getLoc(); 901 Value lhs; 902 Value rhs; 903 if (auto mul = vec.getDefiningOp<arith::MulFOp>()) { 904 lhs = mul.getLhs(); 905 rhs = mul.getRhs(); 906 } else { 907 // If the operand is not a mul, use a vector of ones for the dot operand 908 // to just sum up all values. 909 lhs = vec; 910 Attribute oneAttr = 911 rewriter.getFloatAttr(vectorType.getElementType(), 1.0); 912 oneAttr = SplatElementsAttr::get(vectorType, oneAttr); 913 rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr); 914 } 915 assert(lhs); 916 assert(rhs); 917 918 Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs); 919 if (acc) 920 res = rewriter.create<spirv::FAddOp>(loc, acc, res); 921 922 rewriter.replaceOp(op, res); 923 return success(); 924 } 925 }; 926 927 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> { 928 using OpConversionPattern::OpConversionPattern; 929 930 LogicalResult 931 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor, 932 ConversionPatternRewriter &rewriter) const override { 933 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 934 Type dstType = typeConverter.convertType(stepOp.getType()); 935 if (!dstType) 936 return failure(); 937 938 Location loc = stepOp.getLoc(); 939 int64_t numElements = stepOp.getType().getNumElements(); 940 auto intType = 941 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth()); 942 943 // Input vectors of size 1 are converted to scalars by the type converter. 944 // We just create a constant in this case. 945 if (numElements == 1) { 946 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter); 947 rewriter.replaceOp(stepOp, zero); 948 return success(); 949 } 950 951 SmallVector<Value> source; 952 source.reserve(numElements); 953 for (int64_t i = 0; i < numElements; ++i) { 954 Attribute intAttr = rewriter.getIntegerAttr(intType, i); 955 Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr); 956 source.push_back(constOp); 957 } 958 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType, 959 source); 960 return success(); 961 } 962 }; 963 964 } // namespace 965 #define CL_INT_MAX_MIN_OPS \ 966 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp 967 968 #define GL_INT_MAX_MIN_OPS \ 969 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp 970 971 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp 972 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp 973 974 void mlir::populateVectorToSPIRVPatterns( 975 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 976 patterns.add< 977 VectorBitcastConvert, VectorBroadcastConvert, 978 VectorExtractElementOpConvert, VectorExtractOpConvert, 979 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, 980 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, 981 VectorInsertElementOpConvert, VectorInsertOpConvert, 982 VectorReductionPattern<GL_INT_MAX_MIN_OPS>, 983 VectorReductionPattern<CL_INT_MAX_MIN_OPS>, 984 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, 985 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, 986 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 987 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, 988 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter, 989 VectorStepOpConvert>(typeConverter, patterns.getContext(), 990 PatternBenefit(1)); 991 992 // Make sure that the more specialized dot product pattern has higher benefit 993 // than the generic one that extracts all elements. 994 patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(), 995 PatternBenefit(2)); 996 } 997 998 void mlir::populateVectorReductionToSPIRVDotProductPatterns( 999 RewritePatternSet &patterns) { 1000 patterns.add<VectorReductionToIntDotProd>(patterns.getContext()); 1001 } 1002