//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

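    // Illustrative sketch (assumed shapes and kind; not verbatim IR): with
    // inner-dim reduction and reduction dims [0, 2] on a vector<2x4x8x16xf32>
    // source, the permutation built below is [1, 3, 0, 2], so the rewrite is
    // roughly:
    //   %t = vector.transpose %src, [1, 3, 0, 2]
    //          : vector<2x4x8x16xf32> to vector<4x16x2x8xf32>
    //   %r = vector.multi_reduction <add>, %t, %acc [2, 3]
    //          : vector<4x16x2x8xf32> to vector<4x16xf32>
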
    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose the reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction from n-D to 2-D, given that all
/// reduction dimensions are either inner-most or outer-most.
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If the rank is less than 2, there is nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow at most one scalable dimension. Otherwise we could end up with
    // e.g. `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

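    // Illustrative sketch (assumed shapes and kind; not verbatim IR) of the
    // inner-reduction case handled by the numbered steps below:
    //   vector.multi_reduction <add>, %src, %acc [1, 2]
    //       : vector<2x3x4xf32> to vector<2xf32>
    // is rewritten to roughly:
    //   %s = vector.shape_cast %src : vector<2x3x4xf32> to vector<2x12xf32>
    //   %r = vector.multi_reduction <add>, %s, %acc [1]
    //       : vector<2x12xf32> to vector<2xf32>
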
    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must have at least some parallel or some reduction dims.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly reductionDims.size() + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
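    // E.g. (illustrative, assumed shapes): for the outer-reduction case,
    // vector<2x3x4x5xf32> with reduction dims [0, 1] collapses to
    // vector<6x20xf32> and the reduction mask becomes [true, false].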
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    // 5. Shape cast the mask (if present), the source and the accumulator to
    // the flattened form.
    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 6. Create the flattened form of vector.multi_reduction with the
    // inner-most/outer-most dim as reduction.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 7. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Create a shape cast to restore the flattened result to the original
    // parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with an outer-most reduction dimension and
/// combines the per-row results with element-wise arithmetic.
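/// For example (illustrative sketch with assumed shapes and kind; not verbatim
/// IR):
///   vector.multi_reduction <add>, %src, %acc [0]
///       : vector<4x8xf32> to vector<8xf32>
/// is unrolled into 4 vector.extract ops whose vector<8xf32> results are
/// folded into the accumulator with element-wise arith.addf.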
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: Support masking.
      return failure();

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 2-D vector.multi_reduction with an inner-most reduction
/// dimension into a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts a 1-D vector.multi_reduction with a single reduction dimension to
/// a 2-D form with one parallel and one reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
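/// For example (illustrative sketch with assumed types and kind; not verbatim
/// IR):
///   vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
/// becomes roughly:
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %c, %a [1]
///          : vector<1x8xf32> to vector<1xf32>
///   %e = vector.extract %r[0] : f32 from vector<1xf32>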
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel in front, we need a
    // {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
                                                 this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

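/// Example usage outside the pass (illustrative sketch mirroring
/// runOnOperation above; `ctx` and `rootOp` are assumed to be an MLIRContext
/// and the operation whose body should be rewritten):
///   RewritePatternSet patterns(ctx);
///   vector::populateVectorMultiReductionLoweringPatterns(
///       patterns, vector::VectorMultiReductionLowering::InnerReduction);
///   (void)applyPatternsGreedily(rootOp, std::move(patterns));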
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}