//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
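  // A rough sketch of the two shapes of IR produced here (illustrative only,
  // assuming a fixed vector length of 16):
  //   %mask = vector.broadcast %true : i1 to vector<16xi1>     // special case
  //   %mask = vector.create_mask %end : vector<16xi1>          // general case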
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
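/// For example, a loop that yields 'x + a[i]' for region iteration argument
/// 'x' (passed in as 'iter') is recognized as an ADD reduction. Note that the
/// non-commutative subtraction forms are only accepted with the iteration
/// argument on the left-hand side.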
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
/// Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //    a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //    a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
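    // As a rough sketch (illustrative types only), an i16 coordinate vector
    // loaded from the sparse storage would be widened as
    //   %w = arith.extui %v : vector<16xi16> to vector<16xi32>
    // before it is used as the index vector of a gather/scatter.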
    // Example:
    //    a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //    a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
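      // A sketch of the generated form (illustrative only): the scalar
      // induction variable is broadcast into a vector and an increment
      // vector [0, 1, .., vl-1] (or a scalable step vector) is added to it,
      // as done in the codegen branch below.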
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements).
      // In the vector dialect, this is still
      // represented with an expanded vector at the right-hand side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduction is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
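  //
  // As an illustrative sketch only, a scalar sum reduction
  //   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%x = %init) ...
  // becomes a vector loop carrying a SIMD iteration argument,
  //   %vsum = scf.for %i = %lo to %hi step %c16 iter_args(%vx = %vinit) ...
  //   %sum  = vector.reduction <add>, %vsum
  // with masked memory operations confining the final iteration.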
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
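/// The pattern first runs vectorizeStmt in analysis mode and only rewrites
/// the loop in a second, codegen call once the analysis has succeeded, so no
/// partially vectorized loop is ever left behind.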
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)            v = for { }
///   u = expand(s)    ->    for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
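
// A minimal usage sketch (assumption: invoked from some pass; the driver call
// below is the generic greedy rewriter, not something specific to this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));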