//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing N-D vectors into
// 1-D vectors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

/// Returns true if every vector result of `op` has a trailing dimension
/// narrower than `targetBitWidth` bits, i.e. the op is profitable to
/// linearize under the given target bit width.
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  for (auto resType : op->getResultTypes()) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}
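// Worked example (illustrative only, not taken from this file's tests): an op
// producing vector<4x2xf32> has a trailing dimension of 2 * 32 = 64 bits, so
// isLessThanTargetBitWidth(op, 128) is true (64 < 128), while
// isLessThanTargetBitWidth(op, 64) is false because the comparison above is
// strict (64 >= 64).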
namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());
    // Check for a failed type conversion before querying the result type for
    // scalability.
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts an ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// The following:
///   vector.extract_strided_slice %source
///       { offsets = [..], strides = [..], sizes = [..] }
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
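/// One concrete case (an illustrative sketch; the SSA names and types are
/// hypothetical, not taken from this file's tests):
///   %out = vector.extract_strided_slice %src
///       { offsets = [1], sizes = [2], strides = [1] }
///       : vector<4x4xf32> to vector<2x4xf32>
/// linearizes, with an extraction granularity of 4 (the trailing dimension),
/// to:
///   %src_1d = vector.shape_cast %src : vector<4x4xf32> to vector<16xf32>
///   %out_1d = vector.shuffle %src_1d, %src_1d [4, 5, 6, 7, 8, 9, 10, 11]
///       : vector<16xf32>, vector<16xf32>
///   %out = vector.shape_cast %out_1d : vector<8xf32> to vector<2x4xf32>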
struct LinearizeVectorExtractStridedSlice final
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    // A strided slice with any stride other than 1 is not supported, so check
    // every stride, not just the first.
    for (Attribute stride : strides) {
      if (cast<IntegerAttr>(stride).getInt() != 1)
        return rewriter.notifyMatchFailure(
            extractOp, "Strided slice with stride != 1 is not supported.");
    }
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case the last
    // (n - k) dimensions form the extraction granularity.
    // Example:
    //   vector.extract_strided_slice %src {
    //       offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //       vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1 and
    // compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts a ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// The following:
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to:
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
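/// One concrete case (an illustrative sketch; the SSA names and types are
/// hypothetical, not taken from this file's tests):
///   %out = vector.shuffle %v1, %v2 [1, 0] : vector<2x2xf32>, vector<2x2xf32>
/// linearizes to:
///   %v1_1d = vector.shape_cast %v1 : vector<2x2xf32> to vector<4xf32>
///   %v2_1d = vector.shape_cast %v2 : vector<2x2xf32> to vector<4xf32>
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [2, 3, 0, 1]
///       : vector<4xf32>, vector<4xf32>
///   %out = vector.shape_cast %out_1d : vector<4xf32> to vector<2x2xf32>
/// Each mask entry selects a whole innermost slice (two elements here), so
/// mask entry 1 expands to linear indices [2, 3] and entry 0 to [0, 1].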
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle at the granularity of slices
    // instead of scalars. The size of a slice is the product of the rank - 1
    // innermost dims. Each entry of the shuffle mask selects one such slice
    // from the outermost dim of a source vector.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1, we shuffle slices (runs of shuffleSliceLen
    // consecutive elements) instead of scalars.
    ArrayAttr mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] :
         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
      int64_t v = value.getZExtValue();
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * v);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts an ExtractOp into a ShuffleOp that works on a
/// linearized vector.
/// The following:
///   vector.extract %source [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original
/// extract.
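/// One concrete case (an illustrative sketch; the SSA names and types are
/// hypothetical, not taken from this file's tests):
///   %out = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
/// linearizes to:
///   %src_1d = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
///   %out = vector.shuffle %src_1d, %src_1d [4, 5, 6, 7]
///       : vector<8xf32>, vector<8xf32>
/// The linearized offset is 1 * 4 = 4, and the extracted slice spans the four
/// elements starting there, so no trailing shape_cast is needed for this
/// rank-1 result.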
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstTy).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute the linearized offset of the extracted slice.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += off * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
} // namespace
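// The type converter registered below flattens any linearizable n-D vector
// type to its 1-D form; for example (illustrative, not from this file's
// tests), vector<2x4xf32> converts to vector<8xf32>, and vector.shape_cast
// ops are materialized wherever converted and unconverted values meet.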
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}