1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" 10 #include "../SPIRVCommon/Pattern.h" 11 #include "mlir/Dialect/Index/IR/IndexDialect.h" 12 #include "mlir/Dialect/Index/IR/IndexOps.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace index; 20 21 namespace { 22 23 //===----------------------------------------------------------------------===// 24 // Trivial Conversions 25 //===----------------------------------------------------------------------===// 26 27 using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>; 28 using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>; 29 using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>; 30 using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>; 31 using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>; 32 using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>; 33 using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>; 34 using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>; 35 using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>; 36 using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>; 37 using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>; 38 39 using ConvertIndexShl = 40 spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>; 41 using ConvertIndexShrS = 42 spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>; 43 using ConvertIndexShrU = 44 spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>; 45 46 /// It is the case that when we convert bitwise operations to SPIR-V operations 47 /// we must take into account the special pattern in SPIR-V that if the 48 /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise, 49 /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However, 50 /// index.add is never a boolean operation so we can directly convert it to the 51 /// Bitwise[And|Or]Op. 52 using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>; 53 using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>; 54 using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>; 55 56 //===----------------------------------------------------------------------===// 57 // ConvertConstantBool 58 //===----------------------------------------------------------------------===// 59 60 // Converts index.bool.constant operation to spirv.Constant. 61 struct ConvertIndexConstantBoolOpPattern final 62 : OpConversionPattern<BoolConstantOp> { 63 using OpConversionPattern::OpConversionPattern; 64 65 LogicalResult 66 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor, 67 ConversionPatternRewriter &rewriter) const override { 68 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(), 69 op.getValueAttr()); 70 return success(); 71 } 72 }; 73 74 //===----------------------------------------------------------------------===// 75 // ConvertConstant 76 //===----------------------------------------------------------------------===// 77 78 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32 79 // when required. 80 struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> { 81 using OpConversionPattern::OpConversionPattern; 82 83 LogicalResult 84 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, 85 ConversionPatternRewriter &rewriter) const override { 86 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); 87 Type indexType = typeConverter->getIndexType(); 88 89 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth()); 90 rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 91 op, indexType, IntegerAttr::get(indexType, value)); 92 return success(); 93 } 94 }; 95 96 //===----------------------------------------------------------------------===// 97 // ConvertIndexCeilDivS 98 //===----------------------------------------------------------------------===// 99 100 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then 101 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent 102 /// conversion in IndexToLLVM. 103 struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> { 104 using OpConversionPattern::OpConversionPattern; 105 106 LogicalResult 107 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override { 109 Location loc = op.getLoc(); 110 Value n = adaptor.getLhs(); 111 Type n_type = n.getType(); 112 Value m = adaptor.getRhs(); 113 114 // Define the constants 115 Value zero = rewriter.create<spirv::ConstantOp>( 116 loc, n_type, IntegerAttr::get(n_type, 0)); 117 Value posOne = rewriter.create<spirv::ConstantOp>( 118 loc, n_type, IntegerAttr::get(n_type, 1)); 119 Value negOne = rewriter.create<spirv::ConstantOp>( 120 loc, n_type, IntegerAttr::get(n_type, -1)); 121 122 // Compute `x`. 123 Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero); 124 Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne); 125 126 // Compute the positive result. 127 Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x); 128 Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m); 129 Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne); 130 131 // Compute the negative result. 132 Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n); 133 Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m); 134 Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM); 135 136 // Pick the positive result if `n` and `m` have the same sign and `n` is 137 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. 138 Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero); 139 Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos); 140 Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero); 141 Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero); 142 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes); 143 return success(); 144 } 145 }; 146 147 //===----------------------------------------------------------------------===// 148 // ConvertIndexCeilDivU 149 //===----------------------------------------------------------------------===// 150 151 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken 152 /// from the equivalent conversion in IndexToLLVM. 153 struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> { 154 using OpConversionPattern::OpConversionPattern; 155 156 LogicalResult 157 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, 158 ConversionPatternRewriter &rewriter) const override { 159 Location loc = op.getLoc(); 160 Value n = adaptor.getLhs(); 161 Type n_type = n.getType(); 162 Value m = adaptor.getRhs(); 163 164 // Define the constants 165 Value zero = rewriter.create<spirv::ConstantOp>( 166 loc, n_type, IntegerAttr::get(n_type, 0)); 167 Value one = rewriter.create<spirv::ConstantOp>(loc, n_type, 168 IntegerAttr::get(n_type, 1)); 169 170 // Compute the non-zero result. 171 Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one); 172 Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m); 173 Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one); 174 175 // Pick the result 176 Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero); 177 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne); 178 return success(); 179 } 180 }; 181 182 //===----------------------------------------------------------------------===// 183 // ConvertIndexFloorDivS 184 //===----------------------------------------------------------------------===// 185 186 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then 187 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion 188 /// in IndexToLLVM. 189 struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> { 190 using OpConversionPattern::OpConversionPattern; 191 192 LogicalResult 193 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, 194 ConversionPatternRewriter &rewriter) const override { 195 Location loc = op.getLoc(); 196 Value n = adaptor.getLhs(); 197 Type n_type = n.getType(); 198 Value m = adaptor.getRhs(); 199 200 // Define the constants 201 Value zero = rewriter.create<spirv::ConstantOp>( 202 loc, n_type, IntegerAttr::get(n_type, 0)); 203 Value posOne = rewriter.create<spirv::ConstantOp>( 204 loc, n_type, IntegerAttr::get(n_type, 1)); 205 Value negOne = rewriter.create<spirv::ConstantOp>( 206 loc, n_type, IntegerAttr::get(n_type, -1)); 207 208 // Compute `x`. 209 Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero); 210 Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne); 211 212 // Compute the negative result 213 Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n); 214 Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m); 215 Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM); 216 217 // Compute the positive result. 218 Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m); 219 220 // Pick the negative result if `n` and `m` have different signs and `n` is 221 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. 222 Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero); 223 Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg); 224 Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero); 225 226 Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero); 227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes); 228 return success(); 229 } 230 }; 231 232 //===----------------------------------------------------------------------===// 233 // ConvertIndexCast 234 //===----------------------------------------------------------------------===// 235 236 /// Convert a cast op. If the materialized index type is the same as the other 237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation. 238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts 239 /// zero extend when the result bitwidth is larger. 240 template <typename CastOp, typename ConvertOp> 241 struct ConvertIndexCast final : OpConversionPattern<CastOp> { 242 using OpConversionPattern<CastOp>::OpConversionPattern; 243 244 LogicalResult 245 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, 246 ConversionPatternRewriter &rewriter) const override { 247 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); 248 Type indexType = typeConverter->getIndexType(); 249 250 Type srcType = adaptor.getInput().getType(); 251 Type dstType = op.getType(); 252 if (isa<IndexType>(srcType)) { 253 srcType = indexType; 254 } 255 if (isa<IndexType>(dstType)) { 256 dstType = indexType; 257 } 258 259 if (srcType == dstType) { 260 rewriter.replaceOp(op, adaptor.getInput()); 261 } else { 262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType, 263 adaptor.getOperands()); 264 } 265 return success(); 266 } 267 }; 268 269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>; 270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>; 271 272 //===----------------------------------------------------------------------===// 273 // ConvertIndexCmp 274 //===----------------------------------------------------------------------===// 275 276 // Helper template to replace the operation 277 template <typename ICmpOp> 278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor, 279 ConversionPatternRewriter &rewriter) { 280 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs()); 281 return success(); 282 } 283 284 struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> { 285 using OpConversionPattern::OpConversionPattern; 286 287 LogicalResult 288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, 289 ConversionPatternRewriter &rewriter) const override { 290 // We must convert the predicates to the corresponding int comparions. 291 switch (op.getPred()) { 292 case IndexCmpPredicate::EQ: 293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter); 294 case IndexCmpPredicate::NE: 295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter); 296 case IndexCmpPredicate::SGE: 297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter); 298 case IndexCmpPredicate::SGT: 299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter); 300 case IndexCmpPredicate::SLE: 301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter); 302 case IndexCmpPredicate::SLT: 303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter); 304 case IndexCmpPredicate::UGE: 305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter); 306 case IndexCmpPredicate::UGT: 307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter); 308 case IndexCmpPredicate::ULE: 309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter); 310 case IndexCmpPredicate::ULT: 311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter); 312 } 313 llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern"); 314 } 315 }; 316 317 //===----------------------------------------------------------------------===// 318 // ConvertIndexSizeOf 319 //===----------------------------------------------------------------------===// 320 321 /// Lower `index.sizeof` to a constant with the value of the index bitwidth. 322 struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> { 323 using OpConversionPattern::OpConversionPattern; 324 325 LogicalResult 326 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, 327 ConversionPatternRewriter &rewriter) const override { 328 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); 329 Type indexType = typeConverter->getIndexType(); 330 unsigned bitwidth = typeConverter->getIndexTypeBitwidth(); 331 rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 332 op, indexType, IntegerAttr::get(indexType, bitwidth)); 333 return success(); 334 } 335 }; 336 } // namespace 337 338 //===----------------------------------------------------------------------===// 339 // Pattern Population 340 //===----------------------------------------------------------------------===// 341 342 void index::populateIndexToSPIRVPatterns( 343 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 344 patterns.add< 345 // clang-format off 346 ConvertIndexAdd, 347 ConvertIndexSub, 348 ConvertIndexMul, 349 ConvertIndexDivS, 350 ConvertIndexDivU, 351 ConvertIndexRemS, 352 ConvertIndexRemU, 353 ConvertIndexMaxS, 354 ConvertIndexMaxU, 355 ConvertIndexMinS, 356 ConvertIndexMinU, 357 ConvertIndexShl, 358 ConvertIndexShrS, 359 ConvertIndexShrU, 360 ConvertIndexAnd, 361 ConvertIndexOr, 362 ConvertIndexXor, 363 ConvertIndexConstantBoolOpPattern, 364 ConvertIndexConstantOpPattern, 365 ConvertIndexCeilDivSPattern, 366 ConvertIndexCeilDivUPattern, 367 ConvertIndexFloorDivSPattern, 368 ConvertIndexCastS, 369 ConvertIndexCastU, 370 ConvertIndexCmpPattern, 371 ConvertIndexSizeOf 372 >(typeConverter, patterns.getContext()); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // ODS-Generated Definitions 377 //===----------------------------------------------------------------------===// 378 379 namespace mlir { 380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS 381 #include "mlir/Conversion/Passes.h.inc" 382 } // namespace mlir 383 384 //===----------------------------------------------------------------------===// 385 // Pass Definition 386 //===----------------------------------------------------------------------===// 387 388 namespace { 389 struct ConvertIndexToSPIRVPass 390 : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> { 391 using Base::Base; 392 393 void runOnOperation() override { 394 Operation *op = getOperation(); 395 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); 396 std::unique_ptr<SPIRVConversionTarget> target = 397 SPIRVConversionTarget::get(targetAttr); 398 399 SPIRVConversionOptions options; 400 options.use64bitIndex = this->use64bitIndex; 401 SPIRVTypeConverter typeConverter(targetAttr, options); 402 403 // Use UnrealizedConversionCast as the bridge so that we don't need to pull 404 // in patterns for other dialects. 405 target->addLegalOp<UnrealizedConversionCastOp>(); 406 407 // Allow the spirv operations we are converting to 408 target->addLegalDialect<spirv::SPIRVDialect>(); 409 // Fail hard when there are any remaining 'index' ops. 410 target->addIllegalDialect<index::IndexDialect>(); 411 412 RewritePatternSet patterns(&getContext()); 413 index::populateIndexToSPIRVPatterns(typeConverter, patterns); 414 415 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 416 signalPassFailure(); 417 } 418 }; 419 } // namespace 420