1 //===- VectorLinearize.cpp - vector linearization transforms --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns and pass for linearizing ND vectors into 1D. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "llvm/ADT/ArrayRef.h" 23 #include <cstdint> 24 #include <numeric> 25 26 using namespace mlir; 27 28 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { 29 auto resultTypes = op->getResultTypes(); 30 for (auto resType : resultTypes) { 31 VectorType vecType = dyn_cast<VectorType>(resType); 32 // Reject index since getElementTypeBitWidth will abort for Index types. 33 if (!vecType || vecType.getElementType().isIndex()) 34 return false; 35 // There are no dimension to fold if it is a 0-D vector. 36 if (vecType.getRank() == 0) 37 return false; 38 unsigned trailingVecDimBitWidth = 39 vecType.getShape().back() * vecType.getElementTypeBitWidth(); 40 if (trailingVecDimBitWidth >= targetBitWidth) 41 return false; 42 } 43 return true; 44 } 45 46 static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) { 47 VectorType vecType = dyn_cast<VectorType>(t); 48 // Reject index since getElementTypeBitWidth will abort for Index types. 
49 if (!vecType || vecType.getElementType().isIndex()) 50 return false; 51 // There are no dimension to fold if it is a 0-D vector. 52 if (vecType.getRank() == 0) 53 return false; 54 unsigned trailingVecDimBitWidth = 55 vecType.getShape().back() * vecType.getElementTypeBitWidth(); 56 return trailingVecDimBitWidth <= targetBitWidth; 57 } 58 59 namespace { 60 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> { 61 using OpConversionPattern::OpConversionPattern; 62 LinearizeConstant( 63 const TypeConverter &typeConverter, MLIRContext *context, 64 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 65 PatternBenefit benefit = 1) 66 : OpConversionPattern(typeConverter, context, benefit), 67 targetVectorBitWidth(targetVectBitWidth) {} 68 LogicalResult 69 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, 70 ConversionPatternRewriter &rewriter) const override { 71 Location loc = constOp.getLoc(); 72 auto resType = 73 getTypeConverter()->convertType<VectorType>(constOp.getType()); 74 75 if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue())) 76 return rewriter.notifyMatchFailure( 77 loc, 78 "Cannot linearize a constant scalable vector that's not a splat"); 79 80 if (!resType) 81 return rewriter.notifyMatchFailure(loc, "can't convert return type"); 82 if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) 83 return rewriter.notifyMatchFailure( 84 loc, "Can't flatten since targetBitWidth <= OpSize"); 85 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue()); 86 if (!dstElementsAttr) 87 return rewriter.notifyMatchFailure(loc, "unsupported attr type"); 88 89 dstElementsAttr = dstElementsAttr.reshape(resType); 90 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType, 91 dstElementsAttr); 92 return success(); 93 } 94 95 private: 96 unsigned targetVectorBitWidth; 97 }; 98 99 struct LinearizeVectorizable final 100 : OpTraitConversionPattern<OpTrait::Vectorizable> { 101 using 
OpTraitConversionPattern::OpTraitConversionPattern; 102 103 public: 104 LinearizeVectorizable( 105 const TypeConverter &typeConverter, MLIRContext *context, 106 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 107 PatternBenefit benefit = 1) 108 : OpTraitConversionPattern(typeConverter, context, benefit), 109 targetVectorBitWidth(targetVectBitWidth) {} 110 LogicalResult 111 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 112 ConversionPatternRewriter &rewriter) const override { 113 if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) 114 return rewriter.notifyMatchFailure( 115 op->getLoc(), "Can't flatten since targetBitWidth <= OpSize"); 116 FailureOr<Operation *> newOp = 117 convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); 118 if (failed(newOp)) 119 return failure(); 120 121 rewriter.replaceOp(op, (*newOp)->getResults()); 122 return success(); 123 } 124 125 private: 126 unsigned targetVectorBitWidth; 127 }; 128 129 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works 130 /// on a linearized vector. 131 /// Following, 132 /// vector.extract_strided_slice %source 133 /// { offsets = [..], strides = [..], sizes = [..] } 134 /// is converted to : 135 /// %source_1d = vector.shape_cast %source 136 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] 137 /// %out_nd = vector.shape_cast %out_1d 138 /// `shuffle_indices_1d` is computed using the offsets and sizes of the 139 /// extraction. 
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    // Only contiguous (unit-stride) slices can be expressed as runs of
    // consecutive shuffle indices.
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for nD source vector (n > k), the granularity
    // of the extraction is greater than 1. In this case last (n-k) dimensions
    // form the extraction granularity.
    // Example :
    //  vector.extract_strided_slice %src {
    //  offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //  vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get total number of extracted slices (product of the slice sizes over
    // the leading k dims).
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering first k dimensions.
    // Strides are in units of linearized elements; the innermost of the k
    // dims advances by one whole granule.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // Final shuffle indices has nExtractedSlices * extractGranularitySize
    // elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector (row-major strides of
    // the slice-count space, used to decode a flat slice id below).
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    // Compute extractedStrides.
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector, indices);
    return success();
  }

private:
  // Ops whose trailing-dim bit width is at or above this threshold are left
  // untouched.
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
254 struct LinearizeVectorShuffle final 255 : public OpConversionPattern<vector::ShuffleOp> { 256 using OpConversionPattern::OpConversionPattern; 257 LinearizeVectorShuffle( 258 const TypeConverter &typeConverter, MLIRContext *context, 259 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 260 PatternBenefit benefit = 1) 261 : OpConversionPattern(typeConverter, context, benefit), 262 targetVectorBitWidth(targetVectBitWidth) {} 263 264 LogicalResult 265 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 266 ConversionPatternRewriter &rewriter) const override { 267 VectorType dstType = 268 getTypeConverter()->convertType<VectorType>(shuffleOp.getType()); 269 assert(dstType && "vector type destination expected."); 270 // The assert is used because vector.shuffle does not support scalable 271 // vectors. 272 assert(!(shuffleOp.getV1VectorType().isScalable() || 273 shuffleOp.getV2VectorType().isScalable() || 274 dstType.isScalable()) && 275 "scalable vectors are not supported."); 276 if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) 277 return rewriter.notifyMatchFailure( 278 shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); 279 280 Value vec1 = adaptor.getV1(); 281 Value vec2 = adaptor.getV2(); 282 int shuffleSliceLen = 1; 283 int rank = shuffleOp.getV1().getType().getRank(); 284 285 // If rank > 1, we need to do the shuffle in the granularity of slices 286 // instead of scalars. Size of the slice is equal to the rank-1 innermost 287 // dims. Mask of the shuffle op specifies which slice to take from the 288 // outermost dim. 289 if (rank > 1) { 290 llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape(); 291 for (unsigned i = 1; i < shape.size(); ++i) { 292 shuffleSliceLen *= shape[i]; 293 } 294 } 295 296 // For each value in the mask, we generate the indices of the source vectors 297 // that needs to be shuffled to the destination vector. 
If shuffleSliceLen > 298 // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of 299 // elements) instead of scalars. 300 ArrayRef<int64_t> mask = shuffleOp.getMask(); 301 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; 302 llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts); 303 for (auto [i, value] : llvm::enumerate(mask)) { 304 std::iota(indices.begin() + shuffleSliceLen * i, 305 indices.begin() + shuffleSliceLen * (i + 1), 306 shuffleSliceLen * value); 307 } 308 309 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1, 310 vec2, indices); 311 return success(); 312 } 313 314 private: 315 unsigned targetVectorBitWidth; 316 }; 317 318 /// This pattern converts the ExtractOp to a ShuffleOp that works on a 319 /// linearized vector. 320 /// Following, 321 /// vector.extract %source [ position ] 322 /// is converted to : 323 /// %source_1d = vector.shape_cast %source 324 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] 325 /// %out_nd = vector.shape_cast %out_1d 326 /// `shuffle_indices_1d` is computed using the position of the original extract. 
327 struct LinearizeVectorExtract final 328 : public OpConversionPattern<vector::ExtractOp> { 329 using OpConversionPattern::OpConversionPattern; 330 LinearizeVectorExtract( 331 const TypeConverter &typeConverter, MLIRContext *context, 332 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 333 PatternBenefit benefit = 1) 334 : OpConversionPattern(typeConverter, context, benefit), 335 targetVectorBitWidth(targetVectBitWidth) {} 336 LogicalResult 337 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 338 ConversionPatternRewriter &rewriter) const override { 339 Type dstTy = getTypeConverter()->convertType(extractOp.getType()); 340 if (extractOp.getVector().getType().isScalable() || 341 cast<VectorType>(dstTy).isScalable()) 342 return rewriter.notifyMatchFailure(extractOp, 343 "scalable vectors are not supported."); 344 if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) 345 return rewriter.notifyMatchFailure( 346 extractOp, "Can't flatten since targetBitWidth <= OpSize"); 347 348 // Dynamic position is not supported. 349 if (extractOp.hasDynamicPosition()) 350 return rewriter.notifyMatchFailure(extractOp, 351 "dynamic position is not supported."); 352 353 llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape(); 354 int64_t size = extractOp.getVector().getType().getNumElements(); 355 356 // Compute linearized offset. 
357 int64_t linearizedOffset = 0; 358 llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition(); 359 for (auto [i, off] : llvm::enumerate(offsets)) { 360 size /= shape[i]; 361 linearizedOffset += offsets[i] * size; 362 } 363 364 llvm::SmallVector<int64_t, 2> indices(size); 365 std::iota(indices.begin(), indices.end(), linearizedOffset); 366 rewriter.replaceOpWithNewOp<vector::ShuffleOp>( 367 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices); 368 369 return success(); 370 } 371 372 private: 373 unsigned targetVectorBitWidth; 374 }; 375 376 /// This pattern converts the InsertOp to a ShuffleOp that works on a 377 /// linearized vector. 378 /// Following, 379 /// vector.insert %source %destination [ position ] 380 /// is converted to : 381 /// %source_1d = vector.shape_cast %source 382 /// %destination_1d = vector.shape_cast %destination 383 /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d 384 /// ] %out_nd = vector.shape_cast %out_1d 385 /// `shuffle_indices_1d` is computed using the position of the original insert. 
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    // Note: uses the <= comparison on the *source* type, unlike the other
    // patterns which use the strict comparison on the op's results.
    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // dynamic position is not supported
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      // Scalar insertion would need vector.insertelement-style handling.
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // compute linearized offset: row-major walk over the static position,
    // shrinking dstSizeForOffsets to the stride of the current dimension.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    // Build the shuffle mask over (dest, source): indices [0, dstSize) pick
    // from the linearized dest, indices [dstSize, dstSize+srcSize) pick from
    // the linearized source.
    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

    return success();
  }

private:
  // Source types whose trailing-dim bit width exceeds this threshold are left
  // untouched.
  unsigned targetVectorBitWidth;
};
} // namespace

/// Registers the type conversion (nD vector -> 1-D vector), shape_cast
/// materializations, dynamic-legality rules and the constant/vectorizable
/// linearization patterns.
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  // Flatten any linearizable vector to its 1-D equivalent; other types are
  // returned unchanged.
  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  // Bridge between original and linearized vector types with a shape_cast;
  // returns null (no materialization) for non-vector inputs.
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  // Constants and vectorizable ops are illegal (i.e. will be rewritten) only
  // when below the bit-width threshold and not yet legal w.r.t. the type
  // converter; anything else defers to other legality rules (nullopt).
  // NOTE(review): the [=] capture copies `typeConverter` into the lambda —
  // presumably intentional so the legality check outlives the caller's
  // reference; confirm this matches the converter's lifetime expectations.
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

/// Registers the shuffle/extract/insert/extract_strided_slice linearization
/// patterns and marks vector.shuffle legal once it is rank-1 (or above the
/// bit-width threshold).
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}