xref: /llvm-project/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp (revision db393288ff9b07d2c13b2be6a61f91ba180cb601)
1 //===- Utils.cpp - Transform utilities ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/NVGPU/Transforms/Utils.h"
10 
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 
15 using namespace mlir;
16 using namespace mlir::nvgpu;
17 
getIndices(Operation * op)18 Operation::operand_range nvgpu::getIndices(Operation *op) {
19   if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
20     return ldmatrixOp.getIndices();
21   if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
22     return copyOp.getDstIndices();
23   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
24     return loadOp.getIndices();
25   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
26     return storeOp.getIndices();
27   if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
28     return vectorReadOp.getIndices();
29   if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
30     return vectorStoreOp.getIndices();
31   if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
32     return transferReadOp.getIndices();
33   if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
34     return transferWriteOp.getIndices();
35   llvm_unreachable("unsupported op type");
36 }
37 
setIndices(Operation * op,ArrayRef<Value> indices)38 void nvgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
39   if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
40     return ldmatrixOp.getIndicesMutable().assign(indices);
41   if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
42     return copyOp.getDstIndicesMutable().assign(indices);
43   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
44     return loadOp.getIndicesMutable().assign(indices);
45   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
46     return storeOp.getIndicesMutable().assign(indices);
47   if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
48     return vectorReadOp.getIndicesMutable().assign(indices);
49   if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
50     return vectorStoreOp.getIndicesMutable().assign(indices);
51   if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
52     return transferReadOp.getIndicesMutable().assign(indices);
53   if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
54     return transferWriteOp.getIndicesMutable().assign(indices);
55   llvm_unreachable("unsupported op type");
56 }
57 
getValueStored(Operation * op)58 Value nvgpu::getValueStored(Operation *op) {
59   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
60     return storeOp.getValueToStore();
61   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
62     return transferWrite.getValue();
63   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
64     return storeOp.getValueToStore();
65   llvm_unreachable("unsupported op type");
66 }
67 
getMemrefOperand(Operation * op)68 Value nvgpu::getMemrefOperand(Operation *op) {
69   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
70     return loadOp.getMemref();
71   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
72     return storeOp.getMemref();
73   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
74     return transferWrite.getSource();
75   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
76     return transferRead.getSource();
77   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
78     return storeOp.getBase();
79   if (auto loadOp = dyn_cast<vector::LoadOp>(op))
80     return loadOp.getBase();
81   llvm_unreachable("unsupported op type");
82 }
83