//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}
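// For example (illustrative, hypothetical values): a result of type
// vector<4x8xi16> has a trailing dimension spanning 8 * 16 = 128 bits, so the
// helper above returns true for targetBitWidth = 256 but false for 128.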
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());
    // Check that the type conversion succeeded before using resType below.
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// The following:
///   vector.extract_strided_slice %source
///       { offsets = [..], strides = [..], sizes = [..] }
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
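/// As a worked illustration (hypothetical shapes, not from the original
/// comment):
///   vector.extract_strided_slice %src
///       { offsets = [1], sizes = [2], strides = [1] }
///       : vector<4x4xf32> to vector<2x4xf32>
/// becomes a shuffle of the 16-element linearized source with
/// shuffle_indices_1d = [4, 5, 6, 7, 8, 9, 10, 11], i.e. source rows 1 and 2.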
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (llvm::any_of(strides, [](Attribute stride) {
          return !isConstantIntValue(stride, 1);
        }))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case the last
    // (n-k) dimensions form the extraction granularity.
    // Example:
    //   vector.extract_strided_slice %src {
    //     offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //     vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
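      // For instance (hypothetical values): with sizes = [2, 3], the strides
      // above are extractedStrides = [3, 1], and slice i = 4 delinearizes to
      // multiDimIndex = [1, 1].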
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// The following:
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to:
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle at the granularity of slices
    // instead of scalars. The size of the slice is the product of the
    // (rank - 1) innermost dims. The mask of the shuffle op specifies which
    // slice to take from the outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1, we shuffle slices (runs of shuffleSliceLen
    // consecutive elements) instead of scalars.
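    // For example (hypothetical shapes): with %v1 and %v2 of type
    // vector<2x3xf32> and mask [0, 3], shuffleSliceLen is 3 and mask value 3
    // expands to indices [9, 10, 11], i.e. row 1 of %v2 within the
    // concatenated 12-element linearized vector.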
    ArrayAttr mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] :
         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
      int64_t v = value.getZExtValue();
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * v);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// The following:
///   vector.extract %source [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original
/// extract.
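/// For example (hypothetical shapes): extracting position [1] from a
/// vector<2x4xf32> %source yields shuffle_indices_1d = [4, 5, 6, 7], i.e.
/// the second row of the 8-element linearized vector.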
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out if the result type cannot be converted to a vector type; this
    // also rejects extracts that produce scalars.
    VectorType dstTy =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    if (!dstTy)
      return rewriter.notifyMatchFailure(extractOp,
                                         "can't convert result type");
    if (extractOp.getVector().getType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += off * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// The following:
///   vector.insert %source, %destination [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d
///                 [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original
/// insert.
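/// For example (hypothetical shapes): inserting a vector<2xf32> %source at
/// position [1] of a vector<2x2xf32> %destination yields
/// shuffle_indices_1d = [0, 1, 4, 5]: indices 0-1 keep the original
/// destination values and indices 4-5 (>= dstSize) select the new source
/// values.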
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");

    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    if (!srcAsVec)
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    uint64_t srcSize = srcAsVec.getNumElements();

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    // Original values that remain: [0, offset).
    std::iota(indices.begin(), origValsUntil, 0);
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    // New values from the source: [offset, offset + srcNumElements).
    std::iota(origValsUntil, newValsUntil, dstSize);
    // The rest of the original values: [offset + srcNumElements, end).
    std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}
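// A minimal sketch of how these two entry points are typically driven from a
// conversion pass (hypothetical driver code, not part of this file):
//
//   MLIRContext *ctx = ...;
//   TypeConverter typeConverter;
//   RewritePatternSet patterns(ctx);
//   ConversionTarget target(*ctx);
//   vector::populateVectorLinearizeTypeConversionsAndLegality(
//       typeConverter, patterns, target, /*targetBitWidth=*/512);
//   vector::populateVectorLinearizeShuffleLikeOpsPatterns(
//       typeConverter, patterns, target, /*targetBitWidth=*/512);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();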