//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}
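// Example: a vector<4x8xf16> result has a trailing-dimension bit width of
// 8 * 16 = 128. With targetBitWidth = 256 it passes both helpers; with
// targetBitWidth = 128 it fails the strict check above but passes the
// less-than-or-equal variant below.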
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());

    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///       { offsets = [..], strides = [..], sizes = [..] }
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
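/// For example, extracting sizes = [2, 2] at offsets = [1, 2] from a
/// vector<4x8xf32> source flattens to a shuffle over vector<32xf32> with
/// shuffle_indices_1d = [10, 11, 18, 19].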
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(
          extractOp, "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case, the last
    // (n - k) dimensions form the extraction granularity.
    // Example:
    //   vector.extract_strided_slice %src {
    //       offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //     vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices have nExtractedSlices * extractGranularitySize
    // elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
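    // E.g., for the vector<4x8x8xf32> example above (offsets = [0, 0],
    // sizes = [2, 2]): sourceStrides = [64, 8] and extractedStrides = [2, 1].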
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector, indices);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to:
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
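/// For example, shuffling two vector<2x2xf32> values with mask [1, 2] selects
/// whole rows as 2-element slices, giving shuffle_indices_1d = [2, 3, 4, 5].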
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. The size of a slice is the product of the (rank - 1)
    // innermost dims. Each element of the shuffle mask then selects a slice
    // from the outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1, we need to shuffle the slices (consecutive runs of
    // shuffleSliceLen elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.extract %source [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original extract.
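/// For example, vector.extract %source[1] on a vector<2x4xf32> source flattens
/// to a shuffle over vector<8xf32> with shuffle_indices_1d = [4, 5, 6, 7].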
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    if (!dstTy)
      return rewriter.notifyMatchFailure(extractOp,
                                         "expected n-D vector type.");

    if (extractOp.getVector().getType().isScalable() ||
        cast<VectorType>(dstTy).isScalable())
      return rewriter.notifyMatchFailure(
          extractOp, "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += offsets[i] * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.insert %source %destination [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
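/// For example, inserting a vector<2xf32> source at position [1, 0] of a
/// vector<2x2x2xf32> destination gives
/// shuffle_indices_1d = [0, 1, 2, 3, 8, 9, 6, 7], where indices 8 and 9 refer
/// to the two source elements.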
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the BitCastOp that works on nD (n > 1)
/// vectors to a BitCastOp that works on linearized vectors.
/// Following,
///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
/// is converted to:
///   %v1_1d = vector.shape_cast %v1 : vector<4x2xf32> to vector<8xf32>
///   %out_1d = vector.bitcast %v1_1d : vector<8xf32> to vector<16xf16>
///   %out_nd = vector.shape_cast %out_1d : vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
    : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorBitCast(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = castOp.getLoc();
    auto resType = getTypeConverter()->convertType(castOp.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type.");

    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                   adaptor.getSource());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });
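  // Example: with targetBitWidth = 128, an op producing vector<4x4xi32>
  // (trailing dimension = 4 * 32 = 128 bits) is reported legal above and left
  // untouched, while an op producing vector<4x2xi32> (64 bits) is linearized.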
  patterns
      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
          typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}