//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

/// Returns true if every result of `op` is a vector (with a non-index element
/// type and non-zero rank) whose trailing dimension accounts for fewer than
/// `targetBitWidth` bits.
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}

/// Returns true if `t` is a vector type (with a non-index element type and
/// non-zero rank) whose trailing dimension accounts for at most
/// `targetBitWidth` bits.
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());
    // Check for a null result type before dereferencing it below.
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///       { offsets = [..], strides = [..], sizes = [..] }
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
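///
/// For example (an illustrative case),
///   vector.extract_strided_slice %source
///       { offsets = [1], sizes = [2], strides = [1] }
///     : vector<4x8xf32> to vector<2x8xf32>
/// becomes a shuffle of the linearized vector<32xf32> with
/// `shuffle_indices_1d` = [8, 9, ..., 23], i.e. rows 1 and 2 of the source.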
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case the last
    // (n - k) dimensions form the extraction granularity.
    // Example:
    //   vector.extract_strided_slice %src {
    //       offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //     vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
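      // For example, with sizes = [2, 3] we get extractedStrides = [3, 1],
      // so slice i = 4 delinearizes to multiDimIndex = [1, 1].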
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. The size of a slice is the product of the rank - 1
    // innermost dims. The mask of the shuffle op specifies which slice to
    // take from the outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1 we need to shuffle the slices (consecutive
    // shuffleSliceLen number of elements) instead of scalars.
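    // For example, shuffling two vector<2x2xf32> values with mask [1, 2]
    // gives shuffleSliceLen = 2 and 1-D indices [2, 3, 4, 5]: row 1 of the
    // first vector followed by row 0 of the second (mask values >= 2 refer
    // to %v2).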
    ArrayAttr mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] :
         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
      int64_t v = value.getZExtValue();
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * v);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.extract %source [ position ]
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original extract.
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstTy).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute the linearized offset of the extracted slice.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += off * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.insert %source %destination [ position ]
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d
///                 [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
    assert(!(insertOp.getDestVectorType().isScalable() ||
             cast<VectorType>(dstTy).isScalable()) &&
           "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset of the inserted slice.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of the original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });
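  // For example, the conversion above maps vector<4x8xf32> to vector<32xf32>;
  // types for which isLinearizableVector returns false (such as rank-1
  // vectors) are returned unchanged and therefore remain legal.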
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}