xref: /llvm-project/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
13ba66435SRiver Riddle //===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===//
23ba66435SRiver Riddle //
33ba66435SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43ba66435SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53ba66435SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63ba66435SRiver Riddle //
73ba66435SRiver Riddle //===----------------------------------------------------------------------===//
83ba66435SRiver Riddle //
93ba66435SRiver Riddle // This file implements patterns to convert Tensor dialect to SPIR-V dialect.
103ba66435SRiver Riddle //
113ba66435SRiver Riddle //===----------------------------------------------------------------------===//
123ba66435SRiver Riddle 
133ba66435SRiver Riddle #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
143ba66435SRiver Riddle #include "../SPIRVCommon/Pattern.h"
153ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
163ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
173ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
183ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
193ba66435SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
203ba66435SRiver Riddle #include "mlir/IR/AffineMap.h"
213ba66435SRiver Riddle #include "llvm/Support/Debug.h"
223ba66435SRiver Riddle 
233ba66435SRiver Riddle #define DEBUG_TYPE "tensor-to-spirv-pattern"
243ba66435SRiver Riddle 
253ba66435SRiver Riddle using namespace mlir;
263ba66435SRiver Riddle 
273ba66435SRiver Riddle //===----------------------------------------------------------------------===//
283ba66435SRiver Riddle // Operation conversion
293ba66435SRiver Riddle //===----------------------------------------------------------------------===//
303ba66435SRiver Riddle 
313ba66435SRiver Riddle namespace {
323ba66435SRiver Riddle 
333ba66435SRiver Riddle /// Converts tensor.extract into loading using access chains from SPIR-V local
343ba66435SRiver Riddle /// variables.
353ba66435SRiver Riddle class TensorExtractPattern final
363ba66435SRiver Riddle     : public OpConversionPattern<tensor::ExtractOp> {
373ba66435SRiver Riddle public:
38*206fad0eSMatthias Springer   TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context,
393ba66435SRiver Riddle                        int64_t threshold, PatternBenefit benefit = 1)
403ba66435SRiver Riddle       : OpConversionPattern(typeConverter, context, benefit),
413ba66435SRiver Riddle         byteCountThreshold(threshold) {}
423ba66435SRiver Riddle 
433ba66435SRiver Riddle   LogicalResult
443ba66435SRiver Riddle   matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
453ba66435SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
465550c821STres Popp     auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
473ba66435SRiver Riddle 
48f4b9839dSLongsheng Mou     if (!isa<spirv::ScalarType>(tensorType.getElementType()))
49f4b9839dSLongsheng Mou       return rewriter.notifyMatchFailure(extractOp, "unsupported type");
503ba66435SRiver Riddle     if (!tensorType.hasStaticShape())
513ba66435SRiver Riddle       return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
523ba66435SRiver Riddle 
533ba66435SRiver Riddle     if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
543ba66435SRiver Riddle         byteCountThreshold * 8)
553ba66435SRiver Riddle       return rewriter.notifyMatchFailure(extractOp,
563ba66435SRiver Riddle                                          "exceeding byte count threshold");
573ba66435SRiver Riddle 
583ba66435SRiver Riddle     Location loc = extractOp.getLoc();
593ba66435SRiver Riddle 
603ba66435SRiver Riddle     int64_t rank = tensorType.getRank();
613ba66435SRiver Riddle     SmallVector<int64_t, 4> strides(rank, 1);
623ba66435SRiver Riddle     for (int i = rank - 2; i >= 0; --i) {
633ba66435SRiver Riddle       strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
643ba66435SRiver Riddle     }
653ba66435SRiver Riddle 
668df54a6aSJacques Pienaar     Type varType = spirv::PointerType::get(adaptor.getTensor().getType(),
673ba66435SRiver Riddle                                            spirv::StorageClass::Function);
683ba66435SRiver Riddle 
693ba66435SRiver Riddle     spirv::VariableOp varOp;
708df54a6aSJacques Pienaar     if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
71f148f3b2SStanley Winata       // We could use the initializer directly; but certain driver compilers
725ab6ef75SJakub Kuderski       // have bugs dealing with that. So for now, use spirv.Store for
73f148f3b2SStanley Winata       // initialization.
74f148f3b2SStanley Winata       varOp = rewriter.create<spirv::VariableOp>(loc, varType,
75f148f3b2SStanley Winata                                                  spirv::StorageClass::Function,
76f148f3b2SStanley Winata                                                  /*initializer=*/nullptr);
77f148f3b2SStanley Winata       rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
783ba66435SRiver Riddle     } else {
793ba66435SRiver Riddle       // Need to store the value to the local variable. It's questionable
803ba66435SRiver Riddle       // whether we want to support such case though.
813ba66435SRiver Riddle       return failure();
823ba66435SRiver Riddle     }
833ba66435SRiver Riddle 
843ba66435SRiver Riddle     auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
853ba66435SRiver Riddle     auto indexType = typeConverter.getIndexType();
863ba66435SRiver Riddle 
878df54a6aSJacques Pienaar     Value index = spirv::linearizeIndex(adaptor.getIndices(), strides,
883ba66435SRiver Riddle                                         /*offset=*/0, indexType, loc, rewriter);
893ba66435SRiver Riddle     auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
903ba66435SRiver Riddle 
913ba66435SRiver Riddle     rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
923ba66435SRiver Riddle 
933ba66435SRiver Riddle     return success();
943ba66435SRiver Riddle   }
953ba66435SRiver Riddle 
963ba66435SRiver Riddle private:
973ba66435SRiver Riddle   int64_t byteCountThreshold;
983ba66435SRiver Riddle };
993ba66435SRiver Riddle 
1003ba66435SRiver Riddle } // namespace
1013ba66435SRiver Riddle 
1023ba66435SRiver Riddle //===----------------------------------------------------------------------===//
1033ba66435SRiver Riddle // Pattern population
1043ba66435SRiver Riddle //===----------------------------------------------------------------------===//
1053ba66435SRiver Riddle 
106*206fad0eSMatthias Springer void mlir::populateTensorToSPIRVPatterns(
107*206fad0eSMatthias Springer     const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold,
1083ba66435SRiver Riddle     RewritePatternSet &patterns) {
1093ba66435SRiver Riddle   patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
1103ba66435SRiver Riddle                                      byteCountThreshold);
1113ba66435SRiver Riddle }
112