1 //===- Utils.cpp - Transform utilities ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/NVGPU/Transforms/Utils.h" 10 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 15 using namespace mlir; 16 using namespace mlir::nvgpu; 17 getIndices(Operation * op)18Operation::operand_range nvgpu::getIndices(Operation *op) { 19 if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op)) 20 return ldmatrixOp.getIndices(); 21 if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op)) 22 return copyOp.getDstIndices(); 23 if (auto loadOp = dyn_cast<memref::LoadOp>(op)) 24 return loadOp.getIndices(); 25 if (auto storeOp = dyn_cast<memref::StoreOp>(op)) 26 return storeOp.getIndices(); 27 if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op)) 28 return vectorReadOp.getIndices(); 29 if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op)) 30 return vectorStoreOp.getIndices(); 31 if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) 32 return transferReadOp.getIndices(); 33 if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) 34 return transferWriteOp.getIndices(); 35 llvm_unreachable("unsupported op type"); 36 } 37 setIndices(Operation * op,ArrayRef<Value> indices)38void nvgpu::setIndices(Operation *op, ArrayRef<Value> indices) { 39 if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op)) 40 return ldmatrixOp.getIndicesMutable().assign(indices); 41 if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op)) 42 return copyOp.getDstIndicesMutable().assign(indices); 43 if (auto loadOp = dyn_cast<memref::LoadOp>(op)) 44 return loadOp.getIndicesMutable().assign(indices); 45 if (auto storeOp = dyn_cast<memref::StoreOp>(op)) 46 return storeOp.getIndicesMutable().assign(indices); 47 if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op)) 48 return vectorReadOp.getIndicesMutable().assign(indices); 49 if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op)) 50 return vectorStoreOp.getIndicesMutable().assign(indices); 51 if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op)) 52 return transferReadOp.getIndicesMutable().assign(indices); 53 if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) 54 return transferWriteOp.getIndicesMutable().assign(indices); 55 llvm_unreachable("unsupported op type"); 56 } 57 getValueStored(Operation * op)58Value nvgpu::getValueStored(Operation *op) { 59 if (auto storeOp = dyn_cast<memref::StoreOp>(op)) 60 return storeOp.getValueToStore(); 61 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) 62 return transferWrite.getValue(); 63 if (auto storeOp = dyn_cast<vector::StoreOp>(op)) 64 return storeOp.getValueToStore(); 65 llvm_unreachable("unsupported op type"); 66 } 67 getMemrefOperand(Operation * op)68Value nvgpu::getMemrefOperand(Operation *op) { 69 if (auto loadOp = dyn_cast<memref::LoadOp>(op)) 70 return loadOp.getMemref(); 71 if (auto storeOp = dyn_cast<memref::StoreOp>(op)) 72 return storeOp.getMemref(); 73 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) 74 return transferWrite.getSource(); 75 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) 76 return transferRead.getSource(); 77 if (auto storeOp = dyn_cast<vector::StoreOp>(op)) 78 return storeOp.getBase(); 79 if (auto loadOp = dyn_cast<vector::LoadOp>(op)) 80 return loadOp.getBase(); 81 llvm_unreachable("unsupported op type"); 82 } 83