//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization is concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
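  // As a rough, illustrative sketch (exact IR may differ): for a loop
  // "for i = 0, 128, 16" vectorized with vectorLength 16, this special case
  // would yield something like
  //   %true = arith.constant true
  //   %mask = vector.broadcast %true : i1 to vector<16xi1>
  // instead of a per-iteration vector.create_mask computation.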
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
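///
/// For example (an illustrative sketch, not an exact match of sparsifier
/// output), a sum reduction would appear as
///   %sum = scf.for %i = %lo to %hi step %c1
///            iter_args(%iter = %init) -> (f32) {
///     ...
///     %red = arith.addf %iter, %val : f32
///     scf.yield %red : f32
///   }
/// where the yielded %red combines the iteration argument %iter.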
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
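///
/// As a quick summary of the cases handled below (see the per-case comments
/// in the loop for details):
///   a[inv][i]     : invariant outer subscript, loop index inner
///   a[i][j]       : direct loop indices
///   a[ ind[i] ]   : indirect load, innermost position only
///   a[base + i]   : invariant base plus loop index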
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //    a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //    a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operation, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
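    // As an illustrative sketch of the widening policy (assuming an i16
    // coordinate array and vector length 8; exact IR may differ), the
    // indirect index vector loaded below would be extended as
    //   %c = vector.maskedload ... : ..., vector<8xi16> into vector<8xi16>
    //   %x = arith.extui %c : vector<8xi16> to vector<8xi32>
    // before being used in the gather/scatter.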
    // Example:
    //    a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //    a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) could be generated
  // when a custom reduce is used with a unary operation.
  //    for (...)
  //      yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
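    // Illustrative outcome (a sketch only): a scalar sum reduction becomes a
    // vector loop that yields a vector partial sum, with arith.select on the
    // mask keeping inactive lanes at their previous value, followed by a
    // single vector.reduction <add> after the loop to produce the scalar.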
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)          v = for { }
///   u = expand(s)   ->   for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  vector::populateVectorStepLoweringPatterns(patterns);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
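
// Usage sketch (illustrative only; the driver pass and the exact greedy
// rewrite entry point are assumptions, not part of this file): a pass that
// wants these rules would typically do something along the lines of
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/8,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));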