1 //===- UBToSPIRV.cpp - UB to SPIRV-V dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" 10 11 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 13 #include "mlir/Dialect/UB/IR/UBOps.h" 14 #include "mlir/Pass/Pass.h" 15 16 namespace mlir { 17 #define GEN_PASS_DEF_UBTOSPIRVCONVERSIONPASS 18 #include "mlir/Conversion/Passes.h.inc" 19 } // namespace mlir 20 21 using namespace mlir; 22 23 namespace { 24 25 struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> { 26 using OpConversionPattern::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(ub::PoisonOp op, OpAdaptor, 30 ConversionPatternRewriter &rewriter) const override { 31 Type origType = op.getType(); 32 if (!origType.isIntOrIndexOrFloat()) 33 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 34 diag << "unsupported type " << origType; 35 }); 36 37 Type resType = getTypeConverter()->convertType(origType); 38 if (!resType) 39 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 40 diag << "failed to convert result type " << origType; 41 }); 42 43 rewriter.replaceOpWithNewOp<spirv::UndefOp>(op, resType); 44 return success(); 45 } 46 }; 47 48 } // namespace 49 50 //===----------------------------------------------------------------------===// 51 // Pass Definition 52 //===----------------------------------------------------------------------===// 53 54 namespace { 55 struct UBToSPIRVConversionPass final 56 : impl::UBToSPIRVConversionPassBase<UBToSPIRVConversionPass> { 57 using Base::Base; 58 59 void runOnOperation() override { 60 Operation *op = getOperation(); 61 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); 62 std::unique_ptr<SPIRVConversionTarget> target = 63 SPIRVConversionTarget::get(targetAttr); 64 65 SPIRVConversionOptions options; 66 SPIRVTypeConverter typeConverter(targetAttr, options); 67 68 RewritePatternSet patterns(&getContext()); 69 ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); 70 71 if (failed(applyPartialConversion(op, *target, std::move(patterns)))) 72 signalPassFailure(); 73 } 74 }; 75 } // namespace 76 77 //===----------------------------------------------------------------------===// 78 // Pattern Population 79 //===----------------------------------------------------------------------===// 80 81 void mlir::ub::populateUBToSPIRVConversionPatterns( 82 const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { 83 patterns.add<PoisonOpLowering>(converter, patterns.getContext()); 84 } 85