xref: /llvm-project/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp (revision d5746d73cedcf7a593dc4b4f2ce2465e2d45750b)
1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
10 #include "../SPIRVCommon/Pattern.h"
11 #include "mlir/Dialect/Index/IR/IndexDialect.h"
12 #include "mlir/Dialect/Index/IR/IndexOps.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace index;
20 
21 namespace {
22 
23 //===----------------------------------------------------------------------===//
24 // Trivial Conversions
25 //===----------------------------------------------------------------------===//
26 
27 using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
28 using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
29 using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
30 using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
31 using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
32 using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
33 using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
34 using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
35 using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
36 using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
37 using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
38 
39 using ConvertIndexShl =
40     spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
41 using ConvertIndexShrS =
42     spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
43 using ConvertIndexShrU =
44     spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
45 
46 /// It is the case that when we convert bitwise operations to SPIR-V operations
47 /// we must take into account the special pattern in SPIR-V that if the
48 /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
49 /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
50 /// index.add is never a boolean operation so we can directly convert it to the
51 /// Bitwise[And|Or]Op.
52 using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
53 using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
54 using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
55 
56 //===----------------------------------------------------------------------===//
57 // ConvertConstantBool
58 //===----------------------------------------------------------------------===//
59 
60 // Converts index.bool.constant operation to spirv.Constant.
61 struct ConvertIndexConstantBoolOpPattern final
62     : OpConversionPattern<BoolConstantOp> {
63   using OpConversionPattern::OpConversionPattern;
64 
65   LogicalResult
66   matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
67                   ConversionPatternRewriter &rewriter) const override {
68     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
69                                                    op.getValueAttr());
70     return success();
71   }
72 };
73 
74 //===----------------------------------------------------------------------===//
75 // ConvertConstant
76 //===----------------------------------------------------------------------===//
77 
78 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
79 // when required.
80 struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
81   using OpConversionPattern::OpConversionPattern;
82 
83   LogicalResult
84   matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
85                   ConversionPatternRewriter &rewriter) const override {
86     auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
87     Type indexType = typeConverter->getIndexType();
88 
89     APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
90     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
91         op, indexType, IntegerAttr::get(indexType, value));
92     return success();
93   }
94 };
95 
96 //===----------------------------------------------------------------------===//
97 // ConvertIndexCeilDivS
98 //===----------------------------------------------------------------------===//
99 
100 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
101 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
102 /// conversion in IndexToLLVM.
103 struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
104   using OpConversionPattern::OpConversionPattern;
105 
106   LogicalResult
107   matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override {
109     Location loc = op.getLoc();
110     Value n = adaptor.getLhs();
111     Type n_type = n.getType();
112     Value m = adaptor.getRhs();
113 
114     // Define the constants
115     Value zero = rewriter.create<spirv::ConstantOp>(
116         loc, n_type, IntegerAttr::get(n_type, 0));
117     Value posOne = rewriter.create<spirv::ConstantOp>(
118         loc, n_type, IntegerAttr::get(n_type, 1));
119     Value negOne = rewriter.create<spirv::ConstantOp>(
120         loc, n_type, IntegerAttr::get(n_type, -1));
121 
122     // Compute `x`.
123     Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
124     Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
125 
126     // Compute the positive result.
127     Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
128     Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
129     Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
130 
131     // Compute the negative result.
132     Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
133     Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
134     Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
135 
136     // Pick the positive result if `n` and `m` have the same sign and `n` is
137     // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
138     Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
139     Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140     Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
141     Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
142     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
143     return success();
144   }
145 };
146 
147 //===----------------------------------------------------------------------===//
148 // ConvertIndexCeilDivU
149 //===----------------------------------------------------------------------===//
150 
151 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
152 /// from the equivalent conversion in IndexToLLVM.
153 struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
154   using OpConversionPattern::OpConversionPattern;
155 
156   LogicalResult
157   matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
158                   ConversionPatternRewriter &rewriter) const override {
159     Location loc = op.getLoc();
160     Value n = adaptor.getLhs();
161     Type n_type = n.getType();
162     Value m = adaptor.getRhs();
163 
164     // Define the constants
165     Value zero = rewriter.create<spirv::ConstantOp>(
166         loc, n_type, IntegerAttr::get(n_type, 0));
167     Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
168                                                    IntegerAttr::get(n_type, 1));
169 
170     // Compute the non-zero result.
171     Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
172     Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
173     Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
174 
175     // Pick the result
176     Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
177     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
178     return success();
179   }
180 };
181 
182 //===----------------------------------------------------------------------===//
183 // ConvertIndexFloorDivS
184 //===----------------------------------------------------------------------===//
185 
186 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
187 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
188 /// in IndexToLLVM.
189 struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
190   using OpConversionPattern::OpConversionPattern;
191 
192   LogicalResult
193   matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
194                   ConversionPatternRewriter &rewriter) const override {
195     Location loc = op.getLoc();
196     Value n = adaptor.getLhs();
197     Type n_type = n.getType();
198     Value m = adaptor.getRhs();
199 
200     // Define the constants
201     Value zero = rewriter.create<spirv::ConstantOp>(
202         loc, n_type, IntegerAttr::get(n_type, 0));
203     Value posOne = rewriter.create<spirv::ConstantOp>(
204         loc, n_type, IntegerAttr::get(n_type, 1));
205     Value negOne = rewriter.create<spirv::ConstantOp>(
206         loc, n_type, IntegerAttr::get(n_type, -1));
207 
208     // Compute `x`.
209     Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
210     Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
211 
212     // Compute the negative result
213     Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
214     Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
215     Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
216 
217     // Compute the positive result.
218     Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
219 
220     // Pick the negative result if `n` and `m` have different signs and `n` is
221     // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
222     Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
223     Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224     Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
225 
226     Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
227     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228     return success();
229   }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertIndexCast
234 //===----------------------------------------------------------------------===//
235 
236 /// Convert a cast op. If the materialized index type is the same as the other
237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239 /// zero extend when the result bitwidth is larger.
240 template <typename CastOp, typename ConvertOp>
241 struct ConvertIndexCast final : OpConversionPattern<CastOp> {
242   using OpConversionPattern<CastOp>::OpConversionPattern;
243 
244   LogicalResult
245   matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246                   ConversionPatternRewriter &rewriter) const override {
247     auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248     Type indexType = typeConverter->getIndexType();
249 
250     Type srcType = adaptor.getInput().getType();
251     Type dstType = op.getType();
252     if (isa<IndexType>(srcType)) {
253       srcType = indexType;
254     }
255     if (isa<IndexType>(dstType)) {
256       dstType = indexType;
257     }
258 
259     if (srcType == dstType) {
260       rewriter.replaceOp(op, adaptor.getInput());
261     } else {
262       rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263                                                       adaptor.getOperands());
264     }
265     return success();
266   }
267 };
268 
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
271 
272 //===----------------------------------------------------------------------===//
273 // ConvertIndexCmp
274 //===----------------------------------------------------------------------===//
275 
276 // Helper template to replace the operation
277 template <typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279                                   ConversionPatternRewriter &rewriter) {
280   rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281   return success();
282 }
283 
284 struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
285   using OpConversionPattern::OpConversionPattern;
286 
287   LogicalResult
288   matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289                   ConversionPatternRewriter &rewriter) const override {
290     // We must convert the predicates to the corresponding int comparions.
291     switch (op.getPred()) {
292     case IndexCmpPredicate::EQ:
293       return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294     case IndexCmpPredicate::NE:
295       return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296     case IndexCmpPredicate::SGE:
297       return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298     case IndexCmpPredicate::SGT:
299       return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300     case IndexCmpPredicate::SLE:
301       return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302     case IndexCmpPredicate::SLT:
303       return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304     case IndexCmpPredicate::UGE:
305       return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306     case IndexCmpPredicate::UGT:
307       return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308     case IndexCmpPredicate::ULE:
309       return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310     case IndexCmpPredicate::ULT:
311       return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
312     }
313     llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
314   }
315 };
316 
317 //===----------------------------------------------------------------------===//
318 // ConvertIndexSizeOf
319 //===----------------------------------------------------------------------===//
320 
321 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
322 struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
323   using OpConversionPattern::OpConversionPattern;
324 
325   LogicalResult
326   matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327                   ConversionPatternRewriter &rewriter) const override {
328     auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
329     Type indexType = typeConverter->getIndexType();
330     unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
331     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
332         op, indexType, IntegerAttr::get(indexType, bitwidth));
333     return success();
334   }
335 };
336 } // namespace
337 
338 //===----------------------------------------------------------------------===//
339 // Pattern Population
340 //===----------------------------------------------------------------------===//
341 
342 void index::populateIndexToSPIRVPatterns(
343     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
344   patterns.add<
345       // clang-format off
346     ConvertIndexAdd,
347     ConvertIndexSub,
348     ConvertIndexMul,
349     ConvertIndexDivS,
350     ConvertIndexDivU,
351     ConvertIndexRemS,
352     ConvertIndexRemU,
353     ConvertIndexMaxS,
354     ConvertIndexMaxU,
355     ConvertIndexMinS,
356     ConvertIndexMinU,
357     ConvertIndexShl,
358     ConvertIndexShrS,
359     ConvertIndexShrU,
360     ConvertIndexAnd,
361     ConvertIndexOr,
362     ConvertIndexXor,
363     ConvertIndexConstantBoolOpPattern,
364     ConvertIndexConstantOpPattern,
365     ConvertIndexCeilDivSPattern,
366     ConvertIndexCeilDivUPattern,
367     ConvertIndexFloorDivSPattern,
368     ConvertIndexCastS,
369     ConvertIndexCastU,
370     ConvertIndexCmpPattern,
371     ConvertIndexSizeOf
372   >(typeConverter, patterns.getContext());
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // ODS-Generated Definitions
377 //===----------------------------------------------------------------------===//
378 
379 namespace mlir {
380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381 #include "mlir/Conversion/Passes.h.inc"
382 } // namespace mlir
383 
384 //===----------------------------------------------------------------------===//
385 // Pass Definition
386 //===----------------------------------------------------------------------===//
387 
388 namespace {
389 struct ConvertIndexToSPIRVPass
390     : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
391   using Base::Base;
392 
393   void runOnOperation() override {
394     Operation *op = getOperation();
395     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
396     std::unique_ptr<SPIRVConversionTarget> target =
397       SPIRVConversionTarget::get(targetAttr);
398 
399     SPIRVConversionOptions options;
400     options.use64bitIndex = this->use64bitIndex;
401     SPIRVTypeConverter typeConverter(targetAttr, options);
402 
403     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
404     // in patterns for other dialects.
405     target->addLegalOp<UnrealizedConversionCastOp>();
406 
407     // Allow the spirv operations we are converting to
408     target->addLegalDialect<spirv::SPIRVDialect>();
409     // Fail hard when there are any remaining 'index' ops.
410     target->addIllegalDialect<index::IndexDialect>();
411 
412     RewritePatternSet patterns(&getContext());
413     index::populateIndexToSPIRVPatterns(typeConverter, patterns);
414 
415     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
416       signalPassFailure();
417   }
418 };
419 } // namespace
420