1 //===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Tensor dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" 14 #include "../SPIRVCommon/Pattern.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "llvm/Support/Debug.h" 22 23 #define DEBUG_TYPE "tensor-to-spirv-pattern" 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Operation conversion 29 //===----------------------------------------------------------------------===// 30 31 namespace { 32 33 /// Converts tensor.extract into loading using access chains from SPIR-V local 34 /// variables. 35 class TensorExtractPattern final 36 : public OpConversionPattern<tensor::ExtractOp> { 37 public: 38 TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context, 39 int64_t threshold, PatternBenefit benefit = 1) 40 : OpConversionPattern(typeConverter, context, benefit), 41 byteCountThreshold(threshold) {} 42 43 LogicalResult 44 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override { 46 auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType()); 47 48 if (!isa<spirv::ScalarType>(tensorType.getElementType())) 49 return rewriter.notifyMatchFailure(extractOp, "unsupported type"); 50 if (!tensorType.hasStaticShape()) 51 return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); 52 53 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() > 54 byteCountThreshold * 8) 55 return rewriter.notifyMatchFailure(extractOp, 56 "exceeding byte count threshold"); 57 58 Location loc = extractOp.getLoc(); 59 60 int64_t rank = tensorType.getRank(); 61 SmallVector<int64_t, 4> strides(rank, 1); 62 for (int i = rank - 2; i >= 0; --i) { 63 strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1); 64 } 65 66 Type varType = spirv::PointerType::get(adaptor.getTensor().getType(), 67 spirv::StorageClass::Function); 68 69 spirv::VariableOp varOp; 70 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) { 71 // We could use the initializer directly; but certain driver compilers 72 // have bugs dealing with that. So for now, use spirv.Store for 73 // initialization. 74 varOp = rewriter.create<spirv::VariableOp>(loc, varType, 75 spirv::StorageClass::Function, 76 /*initializer=*/nullptr); 77 rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor()); 78 } else { 79 // Need to store the value to the local variable. It's questionable 80 // whether we want to support such case though. 81 return failure(); 82 } 83 84 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 85 auto indexType = typeConverter.getIndexType(); 86 87 Value index = spirv::linearizeIndex(adaptor.getIndices(), strides, 88 /*offset=*/0, indexType, loc, rewriter); 89 auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index); 90 91 rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp); 92 93 return success(); 94 } 95 96 private: 97 int64_t byteCountThreshold; 98 }; 99 100 } // namespace 101 102 //===----------------------------------------------------------------------===// 103 // Pattern population 104 //===----------------------------------------------------------------------===// 105 106 void mlir::populateTensorToSPIRVPatterns( 107 const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, 108 RewritePatternSet &patterns) { 109 patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(), 110 byteCountThreshold); 111 } 112