xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (revision cb82d375a8060bd3af83b64d7d2c94f4a59d4b97)
//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, nor the general vectorizer we want. It forms a good stepping
// stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
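///
/// For illustration only (a hypothetical configuration, not a default): a VL
/// of {16, /*VLA=*/false, /*Index32=*/true} requests fixed-length
/// vector<16x...> operations and favors 32-bit gather/scatter indices, whereas
/// enabling VLA instead yields scalable vector<[16]x...> operations.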
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for a given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
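///
/// For example (a rough sketch with made-up SSA names, assuming a fixed
/// vector length of 16), the general non-dividing case produces roughly:
///   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
///             (%hi, %iv)[%step]
///   %mask = vector.create_mask %end : vector<16xi1>
/// while an evenly dividing trip count folds to a constant all-true mask.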
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
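///
/// For illustration (a rough sketch; types and SSA names are made up), the
/// unit-stride case emits a masked load such as
///   %v = vector.maskedload %mem[%i], %vmask, %pass : ...
/// while an indirect access through an index vector emits a gather such as
///   %v = vector.gather %mem[%c0] [%idxs], %vmask, %pass : ...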
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
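///
/// For illustration (a rough sketch; types and SSA names are made up), the
/// unit-stride case emits
///   vector.maskedstore %mem[%i], %vmask, %vrhs : ...
/// while an indirect access through an index vector emits
///   vector.scatter %mem[%c0] [%idxs], %vmask, %vrhs : ...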
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of reduction on success in `kind`.
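///
/// For example, a loop body that accumulates `x = x + a[i]` (with `x` the
/// iteration argument) is recognized as CombiningKind::ADD, and `x = x - a[i]`
/// also maps to ADD, since the per-lane partial results are still combined by
/// summation in the final horizontal reduction.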
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// In the scalar case, narrower indices are simply zero extended into 64-bit
/// values before casting to an index type, without a performance penalty.
/// Indices that are already 64-bit cannot, in theory, express the full range,
/// since the LLVM backend defines addressing in terms of an unsigned
/// pointer/signed index pair.
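///
/// For illustration (a rough sketch; names are made up), a load from an i16
/// pointer/index array is zero extended to a 32-bit index vector,
///   %p = vector.maskedload %ptr[%i], %vmask, %pass : ... into vector<16xi16>
///   %x = arith.extui %p : vector<16xi16> to vector<16xi32>
/// whereas an i32 array is extended to vector<16xi64> unless enableSIMDIndex32
/// is set, in which case it is used as-is.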
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (1) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operation, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential for incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync.
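///
/// For example (a rough sketch), an expression such as a[i] * b[i] inside the
/// loop body becomes two masked loads (or gathers) followed by an elementwise
/// arith.mulf on vectors, while a loop-invariant operand is simply broadcast.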
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  // A block argument is invariant.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant as well.
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Handle unary and binary operations inside the loop-body. Note that it
  // would be nicer if we could somehow test and build the operations in a
  // more concise manner than just listing them all (although this way we
  // know for certain that they can vectorize).
  Location loc = forOp.getLoc();
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
  } else if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
    }
  }
  return false;
}

#undef UNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
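///
/// For example (a rough sketch with made-up names, fixed vector length 16), a
/// scalar reduction loop such as
///   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%x = %init) -> (f32) {
///     %t = memref.load %a[%i] : memref<?xf32>
///     %s = arith.addf %x, %t : f32
///     scf.yield %s : f32
///   }
/// is rewritten into a vector loop that yields a vector<16xf32> partial
/// result, followed by a single vector.reduction <add> after the loop.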
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations, which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(
              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are either terminated by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single block, unit-stride for-loop that is generated by the
    // sparse compiler, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)               v = for { }
///   u = expand(s)       ->    for (v) { }
///   for (u) { }
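///
/// The pattern is instantiated for both vector.insertelement and
/// vector.broadcast (see populateSparseVectorizationPatterns below), the two
/// operations used above to embed or expand a scalar value into a vector.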
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(
                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
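///
/// A typical driver might apply these patterns roughly as follows (a sketch
/// only; the surrounding pass setup and the chosen flag values are assumed,
/// not prescribed by this file):
///
///   RewritePatternSet patterns(func.getContext());
///   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
///                                       /*enableVLAVectorization=*/false,
///                                       /*enableSIMDIndex32=*/false);
///   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));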
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}