//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with a proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is neither the final
// abstraction level we want nor the general vectorizer we want. It forms a
// good stepping stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for a given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}
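
// For example (illustrative only): with vl = {16, /*VLA=*/false, ...} and an
// f32 element type this yields vector<16xf32>, whereas with scalable (VLA)
// vectorization enabled it yields the scalable type vector<[16]xf32>.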

/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
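
// As an illustrative sketch (exact IR syntax may differ slightly), for a loop
// "for %iv = %lo to %hi step %step" with vl = 16 the generated mask is roughly
//
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
//             (%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<16xi1>
//
// that is, min(step, hi - iv) lanes are enabled in each iteration.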

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}
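
// As an illustrative sketch (exact IR syntax may differ slightly), a direct
// load a[lo:hi] becomes roughly
//
//   %v = vector.maskedload %a[%iv], %vmask, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//
// while an indirect load a[ind[lo:hi]], whose last index is itself a vector,
// becomes roughly
//
//   %v = vector.gather %a[%c0] [%indexVec], %vmask, %pass
//          : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
//            into vector<16xf32>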

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and, on success, returns the
/// combining kind of the reduction in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}
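
// A small worked example (illustrative only): for a sum reduction with
// incoming scalar r = 5 and vl = 4, the initial vector is | 0 | 0 | 0 | 5 |.
// All other lanes hold the identity element of the combining kind, so the
// partial values accumulated in those lanes do not perturb the final
// horizontal vector.reduction, which therefore yields exactly the scalar
// result.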

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to the index type, without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant/loop indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    if (auto add = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = add.getOperand(0);
      Value idx = add.getOperand(1);
      if (!inv.dyn_cast<BlockArgument>() &&
          inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() &&
          idx.dyn_cast<BlockArgument>()) {
        if (codegen)
          idxs.push_back(
              rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
        continue; // success so far
      }
    }
    return false;
  }
  return true;
}
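
// As an illustrative sketch (exact IR syntax may differ slightly), an indirect
// subscript that loads 32-bit positions from a sparse index array %ind becomes
// roughly
//
//   %i32 = vector.maskedload %ind[%iv], %vmask, %pass : ... into vector<16xi32>
//   %i64 = arith.extui %i32 : vector<16xi32> to vector<16xi64>
//
// where the zero extension to i64 is omitted when enableSIMDIndex32 is set,
// so that the narrower (and typically faster) 32-bit indexed gather/scatter
// can be used directly.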

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    } else {
      // An invariant or reduction. In both cases, we treat this as an
      // invariant value, and rely on later replacing and folding to
      // construct a proper reduction chain for the latter case.
      if (codegen)
        vexp = genVectorInvariantValue(rewriter, vl, exp);
      return true;
    }
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or index values inside the computation that are now fetched from
  // the sparse storage index arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      // TODO: complex?
      // TODO: shift by invariant?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(
              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (the last one is not completely type safe,
        // but all bad uses are removed right away). This also folds away
        // no-op broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}
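
// As an illustrative sketch (exact IR syntax may differ slightly), a scalar
// reduction loop such as
//
//   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%a = %init) -> (f32) {
//     %t = memref.load %b[%i] : memref<?xf32>
//     %s = arith.addf %a, %t : f32
//     scf.yield %s : f32
//   }
//
// is rewritten, for vl = 16, into roughly
//
//   %vinit = vector.insertelement %init, %zero[%c0 : index] : vector<16xf32>
//   %vsum = scf.for %i = %lo to %hi step %c16
//       iter_args(%va = %vinit) -> (vector<16xf32>) {
//     %mask = vector.create_mask ... : vector<16xi1>
//     %vt = vector.maskedload %b[%i], %mask, %pass : ... into vector<16xf32>
//     %vs = arith.addf %va, %vt : vector<16xf32>
//     %vr = arith.select %mask, %vs, %va : vector<16xi1>, vector<16xf32>
//     scf.yield %vr : vector<16xf32>
//   }
//   %sum = vector.reduction <add>, %vsum : vector<16xf32> into f32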

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop generated by the sparse
    // compiler, which means no data dependence analysis is required and the
    // loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)               v = for { }
///   u = expand(s)       ->    for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(
                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
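
// For reference, a minimal usage sketch (the surrounding pass boilerplate is
// hypothetical and only illustrates how a driver might apply these rules):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));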