//===- MemoryPromotion.cpp - Utilities for moving data across GPU memories ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities that allow one to create IR moving the data
// across different levels of the GPU memory hierarchy.
//
//===----------------------------------------------------------------------===//
1308778d8cSAlex Zinenko
14d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/MemoryPromotion.h"
15d7ef488bSMogball
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
18d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
198eb18a0fSNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
21e3cf7c88SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
2208778d8cSAlex Zinenko #include "mlir/Pass/Pass.h"
2308778d8cSAlex Zinenko
2408778d8cSAlex Zinenko using namespace mlir;
2508778d8cSAlex Zinenko using namespace mlir::gpu;
2608778d8cSAlex Zinenko
2708778d8cSAlex Zinenko /// Emits the (imperfect) loop nest performing the copy between "from" and "to"
2808778d8cSAlex Zinenko /// values using the bounds derived from the "from" value. Emits at least
2908778d8cSAlex Zinenko /// GPUDialect::getNumWorkgroupDimensions() loops, completing the nest with
3008778d8cSAlex Zinenko /// single-iteration loops. Maps the innermost loops to thread dimensions, in
3108778d8cSAlex Zinenko /// reverse order to enable access coalescing in the innermost loop.
insertCopyLoops(ImplicitLocOpBuilder & b,Value from,Value to)32e3cf7c88SNicolas Vasilache static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) {
335550c821STres Popp auto memRefType = cast<MemRefType>(from.getType());
34e3cf7c88SNicolas Vasilache auto rank = memRefType.getRank();
35e3cf7c88SNicolas Vasilache
36367229e1SNicolas Vasilache SmallVector<Value, 4> lbs, ubs, steps;
37a54f4eaeSMogball Value zero = b.create<arith::ConstantIndexOp>(0);
38a54f4eaeSMogball Value one = b.create<arith::ConstantIndexOp>(1);
3908778d8cSAlex Zinenko
4008778d8cSAlex Zinenko // Make sure we have enough loops to use all thread dimensions, these trivial
4108778d8cSAlex Zinenko // loops should be outermost and therefore inserted first.
4208778d8cSAlex Zinenko if (rank < GPUDialect::getNumWorkgroupDimensions()) {
4308778d8cSAlex Zinenko unsigned extraLoops = GPUDialect::getNumWorkgroupDimensions() - rank;
4408778d8cSAlex Zinenko lbs.resize(extraLoops, zero);
4508778d8cSAlex Zinenko ubs.resize(extraLoops, one);
4608778d8cSAlex Zinenko steps.resize(extraLoops, one);
4708778d8cSAlex Zinenko }
4808778d8cSAlex Zinenko
4973f371c3SKazuaki Ishizaki // Add existing bounds.
50e3cf7c88SNicolas Vasilache lbs.append(rank, zero);
51e3cf7c88SNicolas Vasilache ubs.reserve(lbs.size());
5208778d8cSAlex Zinenko steps.reserve(lbs.size());
53e3cf7c88SNicolas Vasilache for (auto idx = 0; idx < rank; ++idx) {
54*b23c8225SMatthias Springer ubs.push_back(b.createOrFold<memref::DimOp>(from, idx));
55e3cf7c88SNicolas Vasilache steps.push_back(one);
56e3cf7c88SNicolas Vasilache }
5708778d8cSAlex Zinenko
5808778d8cSAlex Zinenko // Obtain thread identifiers and block sizes, necessary to map to them.
5984a880e1SNicolas Vasilache auto indexType = b.getIndexType();
6008778d8cSAlex Zinenko SmallVector<Value, 3> threadIds, blockDims;
61aae51255SMogball for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) {
62aae51255SMogball threadIds.push_back(b.create<gpu::ThreadIdOp>(indexType, dim));
63aae51255SMogball blockDims.push_back(b.create<gpu::BlockDimOp>(indexType, dim));
6408778d8cSAlex Zinenko }
6508778d8cSAlex Zinenko
6608778d8cSAlex Zinenko // Produce the loop nest with copies.
67367229e1SNicolas Vasilache SmallVector<Value, 8> ivs(lbs.size());
6884a880e1SNicolas Vasilache mlir::scf::buildLoopNest(
69e3cf7c88SNicolas Vasilache b, b.getLoc(), lbs, ubs, steps,
7084a880e1SNicolas Vasilache [&](OpBuilder &b, Location loc, ValueRange loopIvs) {
71d1560f39SAlex Zinenko ivs.assign(loopIvs.begin(), loopIvs.end());
72984b800aSserge-sans-paille auto activeIvs = llvm::ArrayRef(ivs).take_back(rank);
7384a880e1SNicolas Vasilache Value loaded = b.create<memref::LoadOp>(loc, from, activeIvs);
7484a880e1SNicolas Vasilache b.create<memref::StoreOp>(loc, loaded, to, activeIvs);
7508778d8cSAlex Zinenko });
7608778d8cSAlex Zinenko
7708778d8cSAlex Zinenko // Map the innermost loops to threads in reverse order.
78e4853be2SMehdi Amini for (const auto &en :
79984b800aSserge-sans-paille llvm::enumerate(llvm::reverse(llvm::ArrayRef(ivs).take_back(
8008778d8cSAlex Zinenko GPUDialect::getNumWorkgroupDimensions())))) {
81367229e1SNicolas Vasilache Value v = en.value();
82c25b20c0SAlex Zinenko auto loop = cast<scf::ForOp>(v.getParentRegion()->getParentOp());
834c48f016SMatthias Springer affine::mapLoopToProcessorIds(loop, {threadIds[en.index()]},
8408778d8cSAlex Zinenko {blockDims[en.index()]});
8508778d8cSAlex Zinenko }
8608778d8cSAlex Zinenko }
8708778d8cSAlex Zinenko
8808778d8cSAlex Zinenko /// Emits the loop nests performing the copy to the designated location in the
8908778d8cSAlex Zinenko /// beginning of the region, and from the designated location immediately before
9008778d8cSAlex Zinenko /// the terminator of the first block of the region. The region is expected to
9108778d8cSAlex Zinenko /// have one block. This boils down to the following structure
9208778d8cSAlex Zinenko ///
9308778d8cSAlex Zinenko /// ^bb(...):
9408778d8cSAlex Zinenko /// <loop-bound-computation>
9508778d8cSAlex Zinenko /// for %arg0 = ... to ... step ... {
9608778d8cSAlex Zinenko /// ...
9708778d8cSAlex Zinenko /// for %argN = <thread-id-x> to ... step <block-dim-x> {
9808778d8cSAlex Zinenko /// %0 = load %from[%arg0, ..., %argN]
9908778d8cSAlex Zinenko /// store %0, %to[%arg0, ..., %argN]
10008778d8cSAlex Zinenko /// }
10108778d8cSAlex Zinenko /// ...
10208778d8cSAlex Zinenko /// }
10308778d8cSAlex Zinenko /// gpu.barrier
10408778d8cSAlex Zinenko /// <... original body ...>
10508778d8cSAlex Zinenko /// gpu.barrier
10608778d8cSAlex Zinenko /// for %arg0 = ... to ... step ... {
10708778d8cSAlex Zinenko /// ...
10808778d8cSAlex Zinenko /// for %argN = <thread-id-x> to ... step <block-dim-x> {
10908778d8cSAlex Zinenko /// %1 = load %to[%arg0, ..., %argN]
11008778d8cSAlex Zinenko /// store %1, %from[%arg0, ..., %argN]
11108778d8cSAlex Zinenko /// }
11208778d8cSAlex Zinenko /// ...
11308778d8cSAlex Zinenko /// }
11408778d8cSAlex Zinenko ///
11508778d8cSAlex Zinenko /// Inserts the barriers unconditionally since different threads may be copying
11608778d8cSAlex Zinenko /// values and reading them. An analysis would be required to eliminate barriers
11708778d8cSAlex Zinenko /// in case where value is only used by the thread that copies it. Both copies
11808778d8cSAlex Zinenko /// are inserted unconditionally, an analysis would be required to only copy
11908778d8cSAlex Zinenko /// live-in and live-out values when necessary. This copies the entire memref
12008778d8cSAlex Zinenko /// pointed to by "from". In case a smaller block would be sufficient, the
12108778d8cSAlex Zinenko /// caller can create a subview of the memref and promote it instead.
insertCopies(Region & region,Location loc,Value from,Value to)12208778d8cSAlex Zinenko static void insertCopies(Region ®ion, Location loc, Value from, Value to) {
1235550c821STres Popp auto fromType = cast<MemRefType>(from.getType());
1245550c821STres Popp auto toType = cast<MemRefType>(to.getType());
12508778d8cSAlex Zinenko (void)fromType;
12608778d8cSAlex Zinenko (void)toType;
12708778d8cSAlex Zinenko assert(fromType.getShape() == toType.getShape());
12808778d8cSAlex Zinenko assert(fromType.getRank() != 0);
129204c3b55SRiver Riddle assert(llvm::hasSingleElement(region) &&
13008778d8cSAlex Zinenko "unstructured control flow not supported");
13108778d8cSAlex Zinenko
132e3cf7c88SNicolas Vasilache auto b = ImplicitLocOpBuilder::atBlockBegin(loc, ®ion.front());
133e3cf7c88SNicolas Vasilache insertCopyLoops(b, from, to);
134e3cf7c88SNicolas Vasilache b.create<gpu::BarrierOp>();
13508778d8cSAlex Zinenko
13684a880e1SNicolas Vasilache b.setInsertionPoint(®ion.front().back());
137e3cf7c88SNicolas Vasilache b.create<gpu::BarrierOp>();
138e3cf7c88SNicolas Vasilache insertCopyLoops(b, to, from);
13908778d8cSAlex Zinenko }
14008778d8cSAlex Zinenko
14108778d8cSAlex Zinenko /// Promotes a function argument to workgroup memory in the given function. The
14208778d8cSAlex Zinenko /// copies will be inserted in the beginning and in the end of the function.
promoteToWorkgroupMemory(GPUFuncOp op,unsigned arg)14308778d8cSAlex Zinenko void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) {
14408778d8cSAlex Zinenko Value value = op.getArgument(arg);
1455550c821STres Popp auto type = dyn_cast<MemRefType>(value.getType());
14608778d8cSAlex Zinenko assert(type && type.hasStaticShape() && "can only promote memrefs");
14708778d8cSAlex Zinenko
148ad398164SWen-Heng (Jack) Chung // Get the type of the buffer in the workgroup memory.
1496ca1a09fSChristopher Bate auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
1506ca1a09fSChristopher Bate op->getContext(), gpu::AddressSpace::Workgroup);
1516ca1a09fSChristopher Bate auto bufferType = MemRefType::get(type.getShape(), type.getElementType(),
1526ca1a09fSChristopher Bate MemRefLayoutAttrInterface{},
1536ca1a09fSChristopher Bate Attribute(workgroupMemoryAddressSpace));
154e084679fSRiver Riddle Value attribution = op.addWorkgroupAttribution(bufferType, value.getLoc());
15508778d8cSAlex Zinenko
15608778d8cSAlex Zinenko // Replace the uses first since only the original uses are currently present.
15708778d8cSAlex Zinenko // Then insert the copies.
15808778d8cSAlex Zinenko value.replaceAllUsesWith(attribution);
15908778d8cSAlex Zinenko insertCopies(op.getBody(), op.getLoc(), value, attribution);
16008778d8cSAlex Zinenko }
161