//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, nor the general vectorizer we want. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
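///
/// For example (illustrative values only), { 16, /*VLA=*/false, /*idx32=*/true }
/// requests fixed-width 16-lane vectors (e.g. vector<16xf32>) with 32-bit
/// gather/scatter indices, whereas { 4, /*VLA=*/true, /*idx32=*/false } requests
/// scalable vector<[4]xf32> code suitable for ArmSVE-style targets.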
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
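///
/// Schematic of the two generated forms (illustrative; exact IR depends on
/// the bounds and the configured vector length):
///
///   // trip count divisible by the vector length: constant all-true mask
///   %mask = vector.broadcast %true : i1 to vector<16xi1>
///
///   // general case: bound the active lanes by min(hi - iv, step)
///   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
///   %mask = vector.create_mask %end : vector<16xi1>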
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect loads in
/// the last index, i.e. back().
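///
/// Schematically, for illustrative types (vl = 16, f32 elements):
///
///   // direct:    lhs = a[lo:hi]
///   %v = vector.maskedload %mem[%iv], %vmask, %pass
///          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
///   // indirect:  lhs = a[ind[lo:hi]]
///   %v = vector.gather %mem[%c0] [%idxvec], %vmask, %pass
///          : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>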
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparse compiler can only generate indirect stores in
/// the last index, i.e. back().
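///
/// Schematically, for illustrative types (vl = 16, f32 elements):
///
///   // direct:    a[lo:hi] = rhs
///   vector.maskedstore %mem[%iv], %vmask, %rhs
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
///   // indirect:  a[ind[lo:hi]] = rhs
///   vector.scatter %mem[%c0] [%idxvec], %vmask, %rhs
///     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>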
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and, on success, returns the
/// combining kind of the reduction in `kind`.
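///
/// For example, a chain x = x + a[i] (an arith.addf with the iteration
/// argument as either operand) maps to CombiningKind::ADD, whereas for the
/// non-commutative subtraction only x = x - a[i] is accepted, i.e. the
/// iteration argument must be the first operand.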
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
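///
/// For example (illustrative), a sum reduction with scalar init %r starts from
///   %vinit = vector.insertelement %r, %zero[%c0 : index] : vector<16xf32>
/// so that the final horizontal vector.reduction <add> over the loop result
/// yields the same value as the original scalar loop.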
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparse
/// compiler only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme.  Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
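///
/// For example (illustrative), an indirect subscript a[ind[i]] with 8-bit
/// stored coordinates becomes, roughly,
///   %c  = vector.maskedload %ind[%iv], %vmask, %pass : ... into vector<16xi8>
///   %ci = arith.extui %c : vector<16xi8> to vector<16xi32>
/// and %ci is then used as the index vector of a subsequent gather/scatter.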
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operation, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //    a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //    a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
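// For instance, BINOP(arith::AddFOp) expands to a clause that tests whether
// 'def' is an arith.addf and, during the codegen pass, rebuilds it as
// "vexp = rewriter.create<arith::AddFOp>(loc, vx, vy)" on the already
// vectorized operands vx and vy.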

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparse compiler only generates
/// relatively simple expressions inside the for-loops.
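///
/// For example (illustrative), the scalar body expression a[i] * b[i] is
/// rebuilt during codegen as, roughly,
///   %va = vector.maskedload %a[%iv], %vmask, %zero : ... into vector<16xf32>
///   %vb = vector.maskedload %b[%iv], %vmask, %zero : ... into vector<16xf32>
///   %vm = arith.mulf %va, %vb : vector<16xf32>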
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
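///
/// For example (illustrative), a scalar reduction loop
///   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%x = %init) -> (f32)
/// is rebuilt as a new scf.for with step %vl and a vector iter_arg, followed
/// after the loop by a horizontal
///   %res = vector.reduction <add>, %partial : vector<16xf32> into f32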
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations, which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are either terminated by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that was generated by
    // the sparse compiler, which means no data dependence analysis is
    // required, and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)               v = for { }
///   u = expand(s)       ->    for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
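///
/// Typical usage (illustrative sketch, not a prescribed driver):
///
///   RewritePatternSet patterns(ctx);
///   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
///                                       /*enableVLAVectorization=*/false,
///                                       /*enableSIMDIndex32=*/true);
///   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));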
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}