//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparse compiler into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, nor the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
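// For illustration only (a hand-written sketch, not verbatim pass output),
// a unit-stride loop emitted by the sparse compiler, such as
//
//   scf.for %i = %lo to %hi step %c1 {
//     %0 = memref.load %a[%i] : memref<?xf64>
//     ...
//   }
//
// is turned into a loop that steps by the vector length and uses masked
// (or gathered/scattered) vector memory operations, so that the final,
// possibly partial, iteration stays within the original bounds.
//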
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper to test for given index value.
static bool isIntValue(Value val, int64_t idx) {
  if (auto ival = getConstantIntValue(val))
    return *ival == idx;
  return false;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  unsigned numScalableDims = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, numScalableDims);
}

/// Constructs vector type from pointer.
static VectorType vectorType(VL vl, Value ptr) {
  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
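///
/// For example (a sketch only, assuming vl = 16 and no VLA), the general
/// case produces roughly:
///
///   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
///           (%hi, %iv)[%step]
///   %mask = vector.create_mask %end : vector<16xi1>
///
/// whereas a trip count that is a multiple of the vector length folds into
/// a constant all-true mask.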
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
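///
/// A sketch (not verbatim output) of the two forms emitted, assuming vl = 16
/// and f32 elements:
///
///   %v = vector.maskedload %a[%i], %vmask, %pass
///        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
///
/// or, when the innermost subscript is itself a vector (the a[ind[lo:hi]]
/// case), a gather:
///
///   %v = vector.gather %a[%c0][%ind], %vmask, %pass
///        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///          into vector<16xf32>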
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs,
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
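///
/// A sketch (not verbatim output) of the two forms emitted, assuming vl = 16
/// and f32 elements:
///
///   vector.maskedstore %a[%i], %vmask, %rhs
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
///
/// or, when the innermost subscript is itself a vector, a scatter:
///
///   vector.scatter %a[%c0][%ind], %vmask, %rhs
///     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>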
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (idxs.back().getType().isa<VectorType>()) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}

/// Maps operation to combining kind for reduction.
static vector::CombiningKind getCombiningKind(Operation *def) {
  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def) ||
      isa<arith::SubFOp>(def) || isa<arith::SubIOp>(def))
    return vector::CombiningKind::ADD;
  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
    return vector::CombiningKind::MUL;
  if (isa<arith::AndIOp>(def))
    return vector::CombiningKind::AND;
  if (isa<arith::OrIOp>(def))
    return vector::CombiningKind::OR;
  if (isa<arith::XOrIOp>(def))
    return vector::CombiningKind::XOR;
  llvm_unreachable("unknown reduction kind");
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
/// denotes the accumulation operation, which is solely used here to determine
/// the kind of combining reduction (viz. addf -> sum-accumulation).
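///
/// For example (a sketch for a sum-accumulation over vector<16xf32>, where
/// %zeroes denotes an all-zero constant vector):
///
///   %init = vector.insertelement %r, %zeroes[%c0 : index] : vector<16xf32>
///
/// so that lane 0 carries the incoming scalar and the neutral element fills
/// the remaining lanes.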
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                VectorType vtp, Value r, Value rd) {
  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
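///
/// For example (a sketch for a sum-accumulation over vector<16xf32>):
///
///   %res = vector.reduction <add>, %vexp : vector<16xf32> into f32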
static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
                               Value vexp, Value rd) {
  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
  return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a pointer/index load from the sparse storage scheme.
/// Narrower data types need to be zero extended before casting the value
/// into the index type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
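///
/// A sketch of the vectorized subscript path for 16-bit stored indices
/// (assuming vl = 16):
///
///   %ix  = vector.maskedload %ind[%i], %vmask, %pass
///          : memref<?xi16>, vector<16xi1>, vector<16xi16> into vector<16xi16>
///   %ix2 = arith.extui %ix : vector<16xi16> to vector<16xi32>
///
/// after which %ix2 feeds the index-vector operand of a gather or scatter.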
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  for (auto sub : subs) {
    // Invariant indices simply pass through.
    if (sub.dyn_cast<BlockArgument>() ||
        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (1) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance, since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way to
    // state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = vload.getType().cast<VectorType>().getElementType();
        if (!etp.isa<IndexType>()) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
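
// For reference, an invocation such as BINOP(arith::MulFOp) in vectorizeExpr
// below expands to (roughly):
//
//   if (isa<arith::MulFOp>(def)) {
//     if (codegen)
//       vexp = rewriter.create<arith::MulFOp>(loc, vx, vy);
//     return true;
//   }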

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  // A block argument is invariant.
  if (auto arg = exp.dyn_cast<BlockArgument>()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant as well.
  Operation *def = exp.getDefiningOp();
  if (def->getBlock() != &forOp.getRegion().front()) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Handle unary and binary operations inside the loop-body. Note that it
  // would be nicer if we could somehow test and build the operations in a
  // more concise manner than just listing them all (although this way we
  // know for certain that they can vectorize).
  Location loc = forOp.getLoc();
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
  } else if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
    }
  }
  return false;
}

#undef UNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
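///
/// For example (a sketch), a scalar sum-reduction loop of the form
///
///   %sum = scf.for %i = %lo to %hi step %c1
///          iter_args(%r = %init) -> (f64) { ... scf.yield %new : f64 }
///
/// is replaced by a new scf.for whose iter_arg is a vector accumulator
/// (seeded by genVectorReducInit), stepped by the vector length, and
/// followed by a single horizontal vector.reduction after the loop, while a
/// parallel loop keeps its scf.for and only has its step and body rewritten.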
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Location loc = forOp.getLoc();
  Block &block = forOp.getRegion().front();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit =
          genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      forOp.setStep(step);
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    if (yield->getNumOperands() != 1)
      return false;
    Value redOp = yield->getOperand(0);
    // Analyze/vectorize reduction.
    // TODO: use linalg utils to verify the actual reduction?
    Value vrhs;
    if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
      if (codegen) {
        Value vpass =
            genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres =
            genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        forOp.getResult(0).replaceAllUsesWith(vres);
        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
        forOp.getRegionIterArg(0).replaceAllUsesWith(
            forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop generated by the sparse
    // compiler, which means no data dependence analysis is required and the
    // loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
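///
/// A minimal usage sketch from inside a pass (surrounding pass boilerplate
/// assumed, not shown):
///
///   RewritePatternSet patterns(ctx);
///   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
///                                       /*enableVLAVectorization=*/false,
///                                       /*enableSIMDIndex32=*/false);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));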
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
}