13ba66435SRiver Riddle //===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===// 23ba66435SRiver Riddle // 33ba66435SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43ba66435SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53ba66435SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63ba66435SRiver Riddle // 73ba66435SRiver Riddle //===----------------------------------------------------------------------===// 83ba66435SRiver Riddle // 93ba66435SRiver Riddle // This file implements patterns to convert Tensor dialect to SPIR-V dialect. 103ba66435SRiver Riddle // 113ba66435SRiver Riddle //===----------------------------------------------------------------------===// 123ba66435SRiver Riddle 133ba66435SRiver Riddle #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" 143ba66435SRiver Riddle #include "../SPIRVCommon/Pattern.h" 153ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 163ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 173ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 183ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 193ba66435SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h" 203ba66435SRiver Riddle #include "mlir/IR/AffineMap.h" 213ba66435SRiver Riddle #include "llvm/Support/Debug.h" 223ba66435SRiver Riddle 233ba66435SRiver Riddle #define DEBUG_TYPE "tensor-to-spirv-pattern" 243ba66435SRiver Riddle 253ba66435SRiver Riddle using namespace mlir; 263ba66435SRiver Riddle 273ba66435SRiver Riddle //===----------------------------------------------------------------------===// 283ba66435SRiver Riddle // Operation conversion 293ba66435SRiver Riddle //===----------------------------------------------------------------------===// 303ba66435SRiver Riddle 313ba66435SRiver Riddle namespace { 323ba66435SRiver Riddle 333ba66435SRiver Riddle /// Converts tensor.extract into loading using access chains from SPIR-V local 343ba66435SRiver Riddle /// variables. 353ba66435SRiver Riddle class TensorExtractPattern final 363ba66435SRiver Riddle : public OpConversionPattern<tensor::ExtractOp> { 373ba66435SRiver Riddle public: 38*206fad0eSMatthias Springer TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context, 393ba66435SRiver Riddle int64_t threshold, PatternBenefit benefit = 1) 403ba66435SRiver Riddle : OpConversionPattern(typeConverter, context, benefit), 413ba66435SRiver Riddle byteCountThreshold(threshold) {} 423ba66435SRiver Riddle 433ba66435SRiver Riddle LogicalResult 443ba66435SRiver Riddle matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, 453ba66435SRiver Riddle ConversionPatternRewriter &rewriter) const override { 465550c821STres Popp auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType()); 473ba66435SRiver Riddle 48f4b9839dSLongsheng Mou if (!isa<spirv::ScalarType>(tensorType.getElementType())) 49f4b9839dSLongsheng Mou return rewriter.notifyMatchFailure(extractOp, "unsupported type"); 503ba66435SRiver Riddle if (!tensorType.hasStaticShape()) 513ba66435SRiver Riddle return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); 523ba66435SRiver Riddle 533ba66435SRiver Riddle if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() > 543ba66435SRiver Riddle byteCountThreshold * 8) 553ba66435SRiver Riddle return rewriter.notifyMatchFailure(extractOp, 563ba66435SRiver Riddle "exceeding byte count threshold"); 573ba66435SRiver Riddle 583ba66435SRiver Riddle Location loc = extractOp.getLoc(); 593ba66435SRiver Riddle 603ba66435SRiver Riddle int64_t rank = tensorType.getRank(); 613ba66435SRiver Riddle SmallVector<int64_t, 4> strides(rank, 1); 623ba66435SRiver Riddle for (int i = rank - 2; i >= 0; --i) { 633ba66435SRiver Riddle strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1); 643ba66435SRiver Riddle } 653ba66435SRiver Riddle 668df54a6aSJacques Pienaar Type varType = spirv::PointerType::get(adaptor.getTensor().getType(), 673ba66435SRiver Riddle spirv::StorageClass::Function); 683ba66435SRiver Riddle 693ba66435SRiver Riddle spirv::VariableOp varOp; 708df54a6aSJacques Pienaar if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) { 71f148f3b2SStanley Winata // We could use the initializer directly; but certain driver compilers 725ab6ef75SJakub Kuderski // have bugs dealing with that. So for now, use spirv.Store for 73f148f3b2SStanley Winata // initialization. 74f148f3b2SStanley Winata varOp = rewriter.create<spirv::VariableOp>(loc, varType, 75f148f3b2SStanley Winata spirv::StorageClass::Function, 76f148f3b2SStanley Winata /*initializer=*/nullptr); 77f148f3b2SStanley Winata rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor()); 783ba66435SRiver Riddle } else { 793ba66435SRiver Riddle // Need to store the value to the local variable. It's questionable 803ba66435SRiver Riddle // whether we want to support such case though. 813ba66435SRiver Riddle return failure(); 823ba66435SRiver Riddle } 833ba66435SRiver Riddle 843ba66435SRiver Riddle auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 853ba66435SRiver Riddle auto indexType = typeConverter.getIndexType(); 863ba66435SRiver Riddle 878df54a6aSJacques Pienaar Value index = spirv::linearizeIndex(adaptor.getIndices(), strides, 883ba66435SRiver Riddle /*offset=*/0, indexType, loc, rewriter); 893ba66435SRiver Riddle auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index); 903ba66435SRiver Riddle 913ba66435SRiver Riddle rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp); 923ba66435SRiver Riddle 933ba66435SRiver Riddle return success(); 943ba66435SRiver Riddle } 953ba66435SRiver Riddle 963ba66435SRiver Riddle private: 973ba66435SRiver Riddle int64_t byteCountThreshold; 983ba66435SRiver Riddle }; 993ba66435SRiver Riddle 1003ba66435SRiver Riddle } // namespace 1013ba66435SRiver Riddle 1023ba66435SRiver Riddle //===----------------------------------------------------------------------===// 1033ba66435SRiver Riddle // Pattern population 1043ba66435SRiver Riddle //===----------------------------------------------------------------------===// 1053ba66435SRiver Riddle 106*206fad0eSMatthias Springer void mlir::populateTensorToSPIRVPatterns( 107*206fad0eSMatthias Springer const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, 1083ba66435SRiver Riddle RewritePatternSet &patterns) { 1093ba66435SRiver Riddle patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(), 1103ba66435SRiver Riddle byteCountThreshold); 1113ba66435SRiver Riddle } 112