xref: /llvm-project/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (revision 35df525fd00c2037ef144189ee818b7d612241ff)
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Dialect/Utils/StaticValueUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/BuiltinAttributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include <cassert>
35 #include <cstdint>
36 #include <numeric>
37 
38 using namespace mlir;
39 
40 /// Returns the integer value from the first valid input element, assuming Value
41 /// inputs are defined by a constant index ops and Attribute inputs are integer
42 /// attributes.
43 static uint64_t getFirstIntValue(ArrayAttr attr) {
44   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
45 }
46 
47 /// Returns the number of bits for the given scalar/vector type.
48 static int getNumBits(Type type) {
49   // TODO: This does not take into account any memory layout or widening
50   // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
51   // though in practice it will likely be stored as in a 4xi64 vector register.
52   if (auto vectorType = dyn_cast<VectorType>(type))
53     return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
54   return type.getIntOrFloatBitWidth();
55 }
56 
57 namespace {
58 
59 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
60   using OpConversionPattern::OpConversionPattern;
61 
62   LogicalResult
63   matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
64                   ConversionPatternRewriter &rewriter) const override {
65     Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
66     if (!dstType)
67       return failure();
68 
69     // If dstType is same as the source type or the vector size is 1, it can be
70     // directly replaced by the source.
71     if (dstType == adaptor.getSource().getType() ||
72         shapeCastOp.getResultVectorType().getNumElements() == 1) {
73       rewriter.replaceOp(shapeCastOp, adaptor.getSource());
74       return success();
75     }
76 
77     // Lowering for size-n vectors when n > 1 hasn't been implemented.
78     return failure();
79   }
80 };
81 
82 struct VectorBitcastConvert final
83     : public OpConversionPattern<vector::BitCastOp> {
84   using OpConversionPattern::OpConversionPattern;
85 
86   LogicalResult
87   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override {
89     Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
90     if (!dstType)
91       return failure();
92 
93     if (dstType == adaptor.getSource().getType()) {
94       rewriter.replaceOp(bitcastOp, adaptor.getSource());
95       return success();
96     }
97 
98     // Check that the source and destination type have the same bitwidth.
99     // Depending on the target environment, we may need to emulate certain
100     // types, which can cause issue with bitcast.
101     Type srcType = adaptor.getSource().getType();
102     if (getNumBits(dstType) != getNumBits(srcType)) {
103       return rewriter.notifyMatchFailure(
104           bitcastOp,
105           llvm::formatv("different source ({0}) and target ({1}) bitwidth",
106                         srcType, dstType));
107     }
108 
109     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
110                                                   adaptor.getSource());
111     return success();
112   }
113 };
114 
115 struct VectorBroadcastConvert final
116     : public OpConversionPattern<vector::BroadcastOp> {
117   using OpConversionPattern::OpConversionPattern;
118 
119   LogicalResult
120   matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
121                   ConversionPatternRewriter &rewriter) const override {
122     Type resultType =
123         getTypeConverter()->convertType(castOp.getResultVectorType());
124     if (!resultType)
125       return failure();
126 
127     if (isa<spirv::ScalarType>(resultType)) {
128       rewriter.replaceOp(castOp, adaptor.getSource());
129       return success();
130     }
131 
132     SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
133                                  adaptor.getSource());
134     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
135                                                              source);
136     return success();
137   }
138 };
139 
140 struct VectorExtractOpConvert final
141     : public OpConversionPattern<vector::ExtractOp> {
142   using OpConversionPattern::OpConversionPattern;
143 
144   LogicalResult
145   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
146                   ConversionPatternRewriter &rewriter) const override {
147     Type dstType = getTypeConverter()->convertType(extractOp.getType());
148     if (!dstType)
149       return failure();
150 
151     if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
152       rewriter.replaceOp(extractOp, adaptor.getVector());
153       return success();
154     }
155 
156     if (std::optional<int64_t> id =
157             getConstantIntValue(extractOp.getMixedPosition()[0]))
158       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
159           extractOp, dstType, adaptor.getVector(),
160           rewriter.getI32ArrayAttr(id.value()));
161     else
162       rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
163           extractOp, dstType, adaptor.getVector(),
164           adaptor.getDynamicPosition()[0]);
165     return success();
166   }
167 };
168 
169 struct VectorExtractStridedSliceOpConvert final
170     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
171   using OpConversionPattern::OpConversionPattern;
172 
173   LogicalResult
174   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
175                   ConversionPatternRewriter &rewriter) const override {
176     Type dstType = getTypeConverter()->convertType(extractOp.getType());
177     if (!dstType)
178       return failure();
179 
180     uint64_t offset = getFirstIntValue(extractOp.getOffsets());
181     uint64_t size = getFirstIntValue(extractOp.getSizes());
182     uint64_t stride = getFirstIntValue(extractOp.getStrides());
183     if (stride != 1)
184       return failure();
185 
186     Value srcVector = adaptor.getOperands().front();
187 
188     // Extract vector<1xT> case.
189     if (isa<spirv::ScalarType>(dstType)) {
190       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
191                                                              srcVector, offset);
192       return success();
193     }
194 
195     SmallVector<int32_t, 2> indices(size);
196     std::iota(indices.begin(), indices.end(), offset);
197 
198     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
199         extractOp, dstType, srcVector, srcVector,
200         rewriter.getI32ArrayAttr(indices));
201 
202     return success();
203   }
204 };
205 
206 template <class SPIRVFMAOp>
207 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
208   using OpConversionPattern::OpConversionPattern;
209 
210   LogicalResult
211   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
212                   ConversionPatternRewriter &rewriter) const override {
213     Type dstType = getTypeConverter()->convertType(fmaOp.getType());
214     if (!dstType)
215       return failure();
216     rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
217                                             adaptor.getRhs(), adaptor.getAcc());
218     return success();
219   }
220 };
221 
222 struct VectorFromElementsOpConvert final
223     : public OpConversionPattern<vector::FromElementsOp> {
224   using OpConversionPattern::OpConversionPattern;
225 
226   LogicalResult
227   matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
228                   ConversionPatternRewriter &rewriter) const override {
229     Type resultType = getTypeConverter()->convertType(op.getType());
230     if (!resultType)
231       return failure();
232     OperandRange elements = op.getElements();
233     if (isa<spirv::ScalarType>(resultType)) {
234       // In the case with a single scalar operand / single-element result,
235       // pass through the scalar.
236       rewriter.replaceOp(op, elements[0]);
237       return success();
238     }
239     // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
240     // vector.from_elements cases should not need to be handled, only 1d.
241     assert(cast<VectorType>(resultType).getRank() == 1);
242     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
243                                                              elements);
244     return success();
245   }
246 };
247 
248 struct VectorInsertOpConvert final
249     : public OpConversionPattern<vector::InsertOp> {
250   using OpConversionPattern::OpConversionPattern;
251 
252   LogicalResult
253   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
254                   ConversionPatternRewriter &rewriter) const override {
255     if (isa<VectorType>(insertOp.getSourceType()))
256       return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
257     if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
258       return rewriter.notifyMatchFailure(insertOp,
259                                          "unsupported dest vector type");
260 
261     // Special case for inserting scalar values into size-1 vectors.
262     if (insertOp.getSourceType().isIntOrFloat() &&
263         insertOp.getDestVectorType().getNumElements() == 1) {
264       rewriter.replaceOp(insertOp, adaptor.getSource());
265       return success();
266     }
267 
268     if (std::optional<int64_t> id =
269             getConstantIntValue(insertOp.getMixedPosition()[0]))
270       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
271           insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
272     else
273       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
274           insertOp, insertOp.getDest(), adaptor.getSource(),
275           adaptor.getDynamicPosition()[0]);
276     return success();
277   }
278 };
279 
280 struct VectorExtractElementOpConvert final
281     : public OpConversionPattern<vector::ExtractElementOp> {
282   using OpConversionPattern::OpConversionPattern;
283 
284   LogicalResult
285   matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
286                   ConversionPatternRewriter &rewriter) const override {
287     Type resultType = getTypeConverter()->convertType(extractOp.getType());
288     if (!resultType)
289       return failure();
290 
291     if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
292       rewriter.replaceOp(extractOp, adaptor.getVector());
293       return success();
294     }
295 
296     APInt cstPos;
297     if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
298       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
299           extractOp, resultType, adaptor.getVector(),
300           rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
301     else
302       rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
303           extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
304     return success();
305   }
306 };
307 
308 struct VectorInsertElementOpConvert final
309     : public OpConversionPattern<vector::InsertElementOp> {
310   using OpConversionPattern::OpConversionPattern;
311 
312   LogicalResult
313   matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
314                   ConversionPatternRewriter &rewriter) const override {
315     Type vectorType = getTypeConverter()->convertType(insertOp.getType());
316     if (!vectorType)
317       return failure();
318 
319     if (isa<spirv::ScalarType>(vectorType)) {
320       rewriter.replaceOp(insertOp, adaptor.getSource());
321       return success();
322     }
323 
324     APInt cstPos;
325     if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
326       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
327           insertOp, adaptor.getSource(), adaptor.getDest(),
328           cstPos.getSExtValue());
329     else
330       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
331           insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
332           adaptor.getPosition());
333     return success();
334   }
335 };
336 
337 struct VectorInsertStridedSliceOpConvert final
338     : public OpConversionPattern<vector::InsertStridedSliceOp> {
339   using OpConversionPattern::OpConversionPattern;
340 
341   LogicalResult
342   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
343                   ConversionPatternRewriter &rewriter) const override {
344     Value srcVector = adaptor.getOperands().front();
345     Value dstVector = adaptor.getOperands().back();
346 
347     uint64_t stride = getFirstIntValue(insertOp.getStrides());
348     if (stride != 1)
349       return failure();
350     uint64_t offset = getFirstIntValue(insertOp.getOffsets());
351 
352     if (isa<spirv::ScalarType>(srcVector.getType())) {
353       assert(!isa<spirv::ScalarType>(dstVector.getType()));
354       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
355           insertOp, dstVector.getType(), srcVector, dstVector,
356           rewriter.getI32ArrayAttr(offset));
357       return success();
358     }
359 
360     uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
361     uint64_t insertSize =
362         cast<VectorType>(srcVector.getType()).getNumElements();
363 
364     SmallVector<int32_t, 2> indices(totalSize);
365     std::iota(indices.begin(), indices.end(), 0);
366     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
367               totalSize);
368 
369     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
370         insertOp, dstVector.getType(), dstVector, srcVector,
371         rewriter.getI32ArrayAttr(indices));
372 
373     return success();
374   }
375 };
376 
377 static SmallVector<Value> extractAllElements(
378     vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
379     VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
380   int numElements = static_cast<int>(srcVectorType.getDimSize(0));
381   SmallVector<Value> values;
382   values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
383   Location loc = reduceOp.getLoc();
384 
385   for (int i = 0; i < numElements; ++i) {
386     values.push_back(rewriter.create<spirv::CompositeExtractOp>(
387         loc, srcVectorType.getElementType(), adaptor.getVector(),
388         rewriter.getI32ArrayAttr({i})));
389   }
390   if (Value acc = adaptor.getAcc())
391     values.push_back(acc);
392 
393   return values;
394 }
395 
396 struct ReductionRewriteInfo {
397   Type resultType;
398   SmallVector<Value> extractedElements;
399 };
400 
401 FailureOr<ReductionRewriteInfo> static getReductionInfo(
402     vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
403     ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
404   Type resultType = typeConverter.convertType(op.getType());
405   if (!resultType)
406     return failure();
407 
408   auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
409   if (!srcVectorType || srcVectorType.getRank() != 1)
410     return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
411 
412   SmallVector<Value> extractedElements =
413       extractAllElements(op, adaptor, srcVectorType, rewriter);
414 
415   return ReductionRewriteInfo{resultType, std::move(extractedElements)};
416 }
417 
418 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
419           typename SPIRVSMinOp>
420 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
421   using OpConversionPattern::OpConversionPattern;
422 
423   LogicalResult
424   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
425                   ConversionPatternRewriter &rewriter) const override {
426     auto reductionInfo =
427         getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
428     if (failed(reductionInfo))
429       return failure();
430 
431     auto [resultType, extractedElements] = *reductionInfo;
432     Location loc = reduceOp->getLoc();
433     Value result = extractedElements.front();
434     for (Value next : llvm::drop_begin(extractedElements)) {
435       switch (reduceOp.getKind()) {
436 
437 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
438   case vector::CombiningKind::kind:                                            \
439     if (llvm::isa<IntegerType>(resultType)) {                                  \
440       result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
441     } else {                                                                   \
442       assert(llvm::isa<FloatType>(resultType));                                \
443       result = rewriter.create<spirv::fop>(loc, resultType, result, next);     \
444     }                                                                          \
445     break
446 
447 #define INT_OR_FLOAT_CASE(kind, fop)                                           \
448   case vector::CombiningKind::kind:                                            \
449     result = rewriter.create<fop>(loc, resultType, result, next);              \
450     break
451 
452         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
453         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
454         INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
455         INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
456         INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
457         INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
458 
459       case vector::CombiningKind::AND:
460       case vector::CombiningKind::OR:
461       case vector::CombiningKind::XOR:
462         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
463       default:
464         return rewriter.notifyMatchFailure(reduceOp, "not handled here");
465       }
466 #undef INT_AND_FLOAT_CASE
467 #undef INT_OR_FLOAT_CASE
468     }
469 
470     rewriter.replaceOp(reduceOp, result);
471     return success();
472   }
473 };
474 
475 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
476 struct VectorReductionFloatMinMax final
477     : OpConversionPattern<vector::ReductionOp> {
478   using OpConversionPattern::OpConversionPattern;
479 
480   LogicalResult
481   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
482                   ConversionPatternRewriter &rewriter) const override {
483     auto reductionInfo =
484         getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
485     if (failed(reductionInfo))
486       return failure();
487 
488     auto [resultType, extractedElements] = *reductionInfo;
489     Location loc = reduceOp->getLoc();
490     Value result = extractedElements.front();
491     for (Value next : llvm::drop_begin(extractedElements)) {
492       switch (reduceOp.getKind()) {
493 
494 #define INT_OR_FLOAT_CASE(kind, fop)                                           \
495   case vector::CombiningKind::kind:                                            \
496     result = rewriter.create<fop>(loc, resultType, result, next);              \
497     break
498 
499         INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
500         INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
501         INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
502         INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
503 
504       default:
505         return rewriter.notifyMatchFailure(reduceOp, "not handled here");
506       }
507 #undef INT_OR_FLOAT_CASE
508     }
509 
510     rewriter.replaceOp(reduceOp, result);
511     return success();
512   }
513 };
514 
515 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
516 public:
517   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
518 
519   LogicalResult
520   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
521                   ConversionPatternRewriter &rewriter) const override {
522     Type dstType = getTypeConverter()->convertType(op.getType());
523     if (!dstType)
524       return failure();
525     if (isa<spirv::ScalarType>(dstType)) {
526       rewriter.replaceOp(op, adaptor.getInput());
527     } else {
528       auto dstVecType = cast<VectorType>(dstType);
529       SmallVector<Value, 4> source(dstVecType.getNumElements(),
530                                    adaptor.getInput());
531       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
532                                                                source);
533     }
534     return success();
535   }
536 };
537 
538 struct VectorShuffleOpConvert final
539     : public OpConversionPattern<vector::ShuffleOp> {
540   using OpConversionPattern::OpConversionPattern;
541 
542   LogicalResult
543   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
544                   ConversionPatternRewriter &rewriter) const override {
545     VectorType oldResultType = shuffleOp.getResultVectorType();
546     Type newResultType = getTypeConverter()->convertType(oldResultType);
547     if (!newResultType)
548       return rewriter.notifyMatchFailure(shuffleOp,
549                                          "unsupported result vector type");
550 
551     auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
552 
553     VectorType oldV1Type = shuffleOp.getV1VectorType();
554     VectorType oldV2Type = shuffleOp.getV2VectorType();
555 
556     // When both operands and the result are SPIR-V vectors, emit a SPIR-V
557     // shuffle.
558     if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
559         oldResultType.getNumElements() > 1) {
560       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
561           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
562           rewriter.getI32ArrayAttr(mask));
563       return success();
564     }
565 
566     // When at least one of the operands or the result becomes a scalar after
567     // type conversion for SPIR-V, extract all the required elements and
568     // construct the result vector.
569     auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
570                                Value scalarOrVec, int32_t idx) -> Value {
571       if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
572         return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
573                                                           idx);
574 
575       assert(idx == 0 && "Invalid scalar element index");
576       return scalarOrVec;
577     };
578 
579     int32_t numV1Elems = oldV1Type.getNumElements();
580     SmallVector<Value> newOperands(mask.size());
581     for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
582       Value vec = adaptor.getV1();
583       int32_t elementIdx = shuffleIdx;
584       if (elementIdx >= numV1Elems) {
585         vec = adaptor.getV2();
586         elementIdx -= numV1Elems;
587       }
588 
589       newOperand = getElementAtIdx(vec, elementIdx);
590     }
591 
592     // Handle the scalar result corner case.
593     if (newOperands.size() == 1) {
594       rewriter.replaceOp(shuffleOp, newOperands.front());
595       return success();
596     }
597 
598     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
599         shuffleOp, newResultType, newOperands);
600     return success();
601   }
602 };
603 
604 struct VectorInterleaveOpConvert final
605     : public OpConversionPattern<vector::InterleaveOp> {
606   using OpConversionPattern::OpConversionPattern;
607 
608   LogicalResult
609   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
610                   ConversionPatternRewriter &rewriter) const override {
611     // Check the result vector type.
612     VectorType oldResultType = interleaveOp.getResultVectorType();
613     Type newResultType = getTypeConverter()->convertType(oldResultType);
614     if (!newResultType)
615       return rewriter.notifyMatchFailure(interleaveOp,
616                                          "unsupported result vector type");
617 
618     // Interleave the indices.
619     VectorType sourceType = interleaveOp.getSourceVectorType();
620     int n = sourceType.getNumElements();
621 
622     // Input vectors of size 1 are converted to scalars by the type converter.
623     // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
624     // use `spirv::CompositeConstructOp`.
625     if (n == 1) {
626       Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
627       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
628           interleaveOp, newResultType, newOperands);
629       return success();
630     }
631 
632     auto seq = llvm::seq<int64_t>(2 * n);
633     auto indices = llvm::map_to_vector(
634         seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
635 
636     // Emit a SPIR-V shuffle.
637     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
638         interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
639         rewriter.getI32ArrayAttr(indices));
640 
641     return success();
642   }
643 };
644 
645 struct VectorDeinterleaveOpConvert final
646     : public OpConversionPattern<vector::DeinterleaveOp> {
647   using OpConversionPattern::OpConversionPattern;
648 
649   LogicalResult
650   matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
651                   ConversionPatternRewriter &rewriter) const override {
652 
653     // Check the result vector type.
654     VectorType oldResultType = deinterleaveOp.getResultVectorType();
655     Type newResultType = getTypeConverter()->convertType(oldResultType);
656     if (!newResultType)
657       return rewriter.notifyMatchFailure(deinterleaveOp,
658                                          "unsupported result vector type");
659 
660     Location loc = deinterleaveOp->getLoc();
661 
662     // Deinterleave the indices.
663     Value sourceVector = adaptor.getSource();
664     VectorType sourceType = deinterleaveOp.getSourceVectorType();
665     int n = sourceType.getNumElements();
666 
667     // Output vectors of size 1 are converted to scalars by the type converter.
668     // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
669     // use `spirv::CompositeExtractOp`.
670     if (n == 2) {
671       auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
672           loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
673 
674       auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
675           loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
676 
677       rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
678       return success();
679     }
680 
681     // Indices for `shuffleEven` (result 0).
682     auto seqEven = llvm::seq<int64_t>(n / 2);
683     auto indicesEven =
684         llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
685 
686     // Indices for `shuffleOdd` (result 1).
687     auto seqOdd = llvm::seq<int64_t>(n / 2);
688     auto indicesOdd =
689         llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
690 
691     // Create two SPIR-V shuffles.
692     auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
693         loc, newResultType, sourceVector, sourceVector,
694         rewriter.getI32ArrayAttr(indicesEven));
695 
696     auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
697         loc, newResultType, sourceVector, sourceVector,
698         rewriter.getI32ArrayAttr(indicesOdd));
699 
700     rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
701     return success();
702   }
703 };
704 
705 struct VectorLoadOpConverter final
706     : public OpConversionPattern<vector::LoadOp> {
707   using OpConversionPattern::OpConversionPattern;
708 
709   LogicalResult
710   matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
711                   ConversionPatternRewriter &rewriter) const override {
712     auto memrefType = loadOp.getMemRefType();
713     auto attr =
714         dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
715     if (!attr)
716       return rewriter.notifyMatchFailure(
717           loadOp, "expected spirv.storage_class memory space");
718 
719     const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
720     auto loc = loadOp.getLoc();
721     Value accessChain =
722         spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
723                              adaptor.getIndices(), loc, rewriter);
724     if (!accessChain)
725       return rewriter.notifyMatchFailure(
726           loadOp, "failed to get memref element pointer");
727 
728     spirv::StorageClass storageClass = attr.getValue();
729     auto vectorType = loadOp.getVectorType();
730     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
731     Value castedAccessChain =
732         rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
733     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
734                                                castedAccessChain);
735 
736     return success();
737   }
738 };
739 
740 struct VectorStoreOpConverter final
741     : public OpConversionPattern<vector::StoreOp> {
742   using OpConversionPattern::OpConversionPattern;
743 
744   LogicalResult
745   matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
746                   ConversionPatternRewriter &rewriter) const override {
747     auto memrefType = storeOp.getMemRefType();
748     auto attr =
749         dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
750     if (!attr)
751       return rewriter.notifyMatchFailure(
752           storeOp, "expected spirv.storage_class memory space");
753 
754     const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
755     auto loc = storeOp.getLoc();
756     Value accessChain =
757         spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
758                              adaptor.getIndices(), loc, rewriter);
759     if (!accessChain)
760       return rewriter.notifyMatchFailure(
761           storeOp, "failed to get memref element pointer");
762 
763     spirv::StorageClass storageClass = attr.getValue();
764     auto vectorType = storeOp.getVectorType();
765     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
766     Value castedAccessChain =
767         rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
768     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
769                                                 adaptor.getValueToStore());
770 
771     return success();
772   }
773 };
774 
775 struct VectorReductionToIntDotProd final
776     : OpRewritePattern<vector::ReductionOp> {
777   using OpRewritePattern::OpRewritePattern;
778 
779   LogicalResult matchAndRewrite(vector::ReductionOp op,
780                                 PatternRewriter &rewriter) const override {
781     if (op.getKind() != vector::CombiningKind::ADD)
782       return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
783 
784     auto resultType = dyn_cast<IntegerType>(op.getType());
785     if (!resultType)
786       return rewriter.notifyMatchFailure(op, "result is not an integer");
787 
788     int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
789     if (!llvm::is_contained({32, 64}, resultBitwidth))
790       return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
791 
792     VectorType inVecTy = op.getSourceVectorType();
793     if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
794         inVecTy.getShape().size() != 1 || inVecTy.isScalable())
795       return rewriter.notifyMatchFailure(op, "unsupported vector shape");
796 
797     auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
798     if (!mul)
799       return rewriter.notifyMatchFailure(
800           op, "reduction operand is not 'arith.muli'");
801 
802     if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
803                              spirv::SDotAccSatOp, false>(op, mul, rewriter)))
804       return success();
805 
806     if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
807                              spirv::UDotAccSatOp, false>(op, mul, rewriter)))
808       return success();
809 
810     if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
811                              spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
812       return success();
813 
814     if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
815                              spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
816       return success();
817 
818     return failure();
819   }
820 
821 private:
822   template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
823             typename DotAccOp, bool SwapOperands>
824   static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
825                                   PatternRewriter &rewriter) {
826     auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
827     if (!lhs)
828       return failure();
829     Value lhsIn = lhs.getIn();
830     auto lhsInType = cast<VectorType>(lhsIn.getType());
831     if (!lhsInType.getElementType().isInteger(8))
832       return failure();
833 
834     auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
835     if (!rhs)
836       return failure();
837     Value rhsIn = rhs.getIn();
838     auto rhsInType = cast<VectorType>(rhsIn.getType());
839     if (!rhsInType.getElementType().isInteger(8))
840       return failure();
841 
842     if (op.getSourceVectorType().getNumElements() == 3) {
843       IntegerType i8Type = rewriter.getI8Type();
844       auto v4i8Type = VectorType::get({4}, i8Type);
845       Location loc = op.getLoc();
846       Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
847       lhsIn = rewriter.create<spirv::CompositeConstructOp>(
848           loc, v4i8Type, ValueRange{lhsIn, zero});
849       rhsIn = rewriter.create<spirv::CompositeConstructOp>(
850           loc, v4i8Type, ValueRange{rhsIn, zero});
851     }
852 
853     // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
854     // we have to swap operands instead in that case.
855     if (SwapOperands)
856       std::swap(lhsIn, rhsIn);
857 
858     if (Value acc = op.getAcc()) {
859       rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
860                                             nullptr);
861     } else {
862       rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
863                                          nullptr);
864     }
865 
866     return success();
867   }
868 };
869 
870 struct VectorReductionToFPDotProd final
871     : OpConversionPattern<vector::ReductionOp> {
872   using OpConversionPattern::OpConversionPattern;
873 
874   LogicalResult
875   matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
876                   ConversionPatternRewriter &rewriter) const override {
877     if (op.getKind() != vector::CombiningKind::ADD)
878       return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
879 
880     auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
881     if (!resultType)
882       return rewriter.notifyMatchFailure(op, "result is not a float");
883 
884     Value vec = adaptor.getVector();
885     Value acc = adaptor.getAcc();
886 
887     auto vectorType = dyn_cast<VectorType>(vec.getType());
888     if (!vectorType) {
889       assert(isa<FloatType>(vec.getType()) &&
890              "Expected the vector to be scalarized");
891       if (acc) {
892         rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
893         return success();
894       }
895 
896       rewriter.replaceOp(op, vec);
897       return success();
898     }
899 
900     Location loc = op.getLoc();
901     Value lhs;
902     Value rhs;
903     if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
904       lhs = mul.getLhs();
905       rhs = mul.getRhs();
906     } else {
907       // If the operand is not a mul, use a vector of ones for the dot operand
908       // to just sum up all values.
909       lhs = vec;
910       Attribute oneAttr =
911           rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
912       oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
913       rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
914     }
915     assert(lhs);
916     assert(rhs);
917 
918     Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
919     if (acc)
920       res = rewriter.create<spirv::FAddOp>(loc, acc, res);
921 
922     rewriter.replaceOp(op, res);
923     return success();
924   }
925 };
926 
927 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
928   using OpConversionPattern::OpConversionPattern;
929 
930   LogicalResult
931   matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
932                   ConversionPatternRewriter &rewriter) const override {
933     const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
934     Type dstType = typeConverter.convertType(stepOp.getType());
935     if (!dstType)
936       return failure();
937 
938     Location loc = stepOp.getLoc();
939     int64_t numElements = stepOp.getType().getNumElements();
940     auto intType =
941         rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
942 
943     // Input vectors of size 1 are converted to scalars by the type converter.
944     // We just create a constant in this case.
945     if (numElements == 1) {
946       Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
947       rewriter.replaceOp(stepOp, zero);
948       return success();
949     }
950 
951     SmallVector<Value> source;
952     source.reserve(numElements);
953     for (int64_t i = 0; i < numElements; ++i) {
954       Attribute intAttr = rewriter.getIntegerAttr(intType, i);
955       Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
956       source.push_back(constOp);
957     }
958     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
959                                                              source);
960     return success();
961   }
962 };
963 
964 } // namespace
965 #define CL_INT_MAX_MIN_OPS                                                     \
966   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
967 
968 #define GL_INT_MAX_MIN_OPS                                                     \
969   spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
970 
971 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
972 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
973 
974 void mlir::populateVectorToSPIRVPatterns(
975     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
976   patterns.add<
977       VectorBitcastConvert, VectorBroadcastConvert,
978       VectorExtractElementOpConvert, VectorExtractOpConvert,
979       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
980       VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
981       VectorInsertElementOpConvert, VectorInsertOpConvert,
982       VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
983       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
984       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
985       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
986       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
987       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
988       VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
989       VectorStepOpConvert>(typeConverter, patterns.getContext(),
990                            PatternBenefit(1));
991 
992   // Make sure that the more specialized dot product pattern has higher benefit
993   // than the generic one that extracts all elements.
994   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
995                                            PatternBenefit(2));
996 }
997 
998 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
999     RewritePatternSet &patterns) {
1000   patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
1001 }
1002