//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
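  // For example (a sketch; the exact IR depends on later folding), with
  // vl = 16 the loop "for i = 0, 128, 16" above yields the constant mask
  //   %t = arith.constant true
  //   %m = vector.broadcast %t : i1 to vector<16xi1>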
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Maps operation to combining kind for reduction.
static vector::CombiningKind getCombiningKind(Operation *def) {
  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def) ||
      isa<arith::SubFOp>(def) || isa<arith::SubIOp>(def))
    return vector::CombiningKind::ADD;
  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
    return vector::CombiningKind::MUL;
  if (isa<arith::AndIOp>(def))
    return vector::CombiningKind::AND;
  if (isa<arith::OrIOp>(def))
    return vector::CombiningKind::OR;
  if (isa<arith::XOrIOp>(def))
    return vector::CombiningKind::XOR;
  llvm_unreachable("unknown reduction kind");
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
/// denotes the accumulation operation, which is solely used here to determine
/// the kind of combining reduction (viz. addf -> sum-accumulation).
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                VectorType vtp, Value r, Value rd) {
  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
                               Value vexp, Value rd) {
  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
  return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
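    // For example (a sketch), the stripping above traces a subscript such as
    //   %x = memref.load %ind[%i] : memref<?xi8>
    //   %0 = arith.extui %x : i8 to i64
    //   %1 = arith.index_cast %0 : i64 to index
    // back to the narrow load %x, which is then handled in vector form below.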
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  // A block argument is invariant.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant as well.
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
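  // For example (a sketch), an expression a[i] * b[i] + s inside the loop
  // body becomes two masked vector loads, a vector arith.mulf, a broadcast
  // of the loop-invariant s, and a vector arith.addf.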
  Location loc = forOp.getLoc();
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
  } else if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
    }
  }
  return false;
}

#undef UNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
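  // As a sketch (assuming an f32 sum reduction and vl = 16), a loop
  //   %r = scf.for %i = %lo to %hi step %c1 iter_args(%a = %init) -> (f32)
  // is replaced by a vector loop
  //   %v = scf.for %i = %lo to %hi step %c16
  //            iter_args(%va = %vinit) -> (vector<16xf32>)
  // whose result is combined into a scalar with a single vector.reduction.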
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit =
          genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    if (yield->getNumOperands() != 1)
      return false;
    Value redOp = yield->getOperand(0);
    // Analyze/vectorize reduction.
    // TODO: use linalg utils to verify the actual reduction?
    Value vrhs;
    if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
      if (codegen) {
        Value vpass =
            genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres =
            genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that was generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop body is very restricted in form.
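    // For example (a sketch; the exact attribute comes from the loop
    // emitter), a typical candidate has the shape
    //   scf.for %i = %lo to %hi step %c1 {
    //     %x = memref.load %values[%i] : memref<?xf64>
    //     ...
    //   }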
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
}
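
// Example usage from a driving pass (a sketch; the pass wiring and the
// greedy rewrite driver invocation live outside this file):
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));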