xref: /llvm-project/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp (revision 0aa831e0edb1c1deabb96ce2435667cc82bac79b)
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
12 #include "mlir/Interfaces/InferIntRangeInterface.h"
13 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
14 
15 using namespace mlir;
16 using namespace mlir::gpu;
17 
18 namespace {
19 /// Implement ValueBoundsOpInterface (which only works on index-typed values,
20 /// gathers a set of constraint expressions, and is used for affine analyses)
21 /// in terms of InferIntRangeInterface (which works
22 /// on arbitrary integer types, creates [min, max] ranges, and is used in for
23 /// arithmetic simplification).
24 template <typename Op>
25 struct GpuIdOpInterface
26     : public ValueBoundsOpInterface::ExternalModel<GpuIdOpInterface<Op>, Op> {
27   void populateBoundsForIndexValue(Operation *op, Value value,
28                                    ValueBoundsConstraintSet &cstr) const {
29     auto inferrable = cast<InferIntRangeInterface>(op);
30     assert(value == op->getResult(0) &&
31            "inferring for value that isn't the GPU op's result");
32     auto translateConstraint = [&](Value v, const ConstantIntRanges &range) {
33       assert(v == value &&
34              "GPU ID op inferring values for something that's not its result");
35       cstr.bound(v) >= range.smin().getSExtValue();
36       cstr.bound(v) <= range.smax().getSExtValue();
37     };
38     assert(inferrable->getNumOperands() == 0 && "ID ops have no operands");
39     inferrable.inferResultRanges({}, translateConstraint);
40   }
41 };
42 
43 struct GpuLaunchOpInterface
44     : public ValueBoundsOpInterface::ExternalModel<GpuLaunchOpInterface,
45                                                    LaunchOp> {
46   void populateBoundsForIndexValue(Operation *op, Value value,
47                                    ValueBoundsConstraintSet &cstr) const {
48     auto launchOp = cast<LaunchOp>(op);
49 
50     Value sizeArg = nullptr;
51     bool isSize = false;
52     KernelDim3 gridSizeArgs = launchOp.getGridSizeOperandValues();
53     KernelDim3 blockSizeArgs = launchOp.getBlockSizeOperandValues();
54 
55     auto match = [&](KernelDim3 bodyArgs, KernelDim3 externalArgs,
56                      bool areSizeArgs) {
57       if (value == bodyArgs.x) {
58         sizeArg = externalArgs.x;
59         isSize = areSizeArgs;
60       }
61       if (value == bodyArgs.y) {
62         sizeArg = externalArgs.y;
63         isSize = areSizeArgs;
64       }
65       if (value == bodyArgs.z) {
66         sizeArg = externalArgs.z;
67         isSize = areSizeArgs;
68       }
69     };
70     match(launchOp.getThreadIds(), blockSizeArgs, false);
71     match(launchOp.getBlockSize(), blockSizeArgs, true);
72     match(launchOp.getBlockIds(), gridSizeArgs, false);
73     match(launchOp.getGridSize(), gridSizeArgs, true);
74     if (launchOp.hasClusterSize()) {
75       KernelDim3 clusterSizeArgs = *launchOp.getClusterSizeOperandValues();
76       match(*launchOp.getClusterIds(), clusterSizeArgs, false);
77       match(*launchOp.getClusterSize(), clusterSizeArgs, true);
78     }
79 
80     if (!sizeArg)
81       return;
82     if (isSize) {
83       cstr.bound(value) == cstr.getExpr(sizeArg);
84       cstr.bound(value) >= 1;
85     } else {
86       cstr.bound(value) < cstr.getExpr(sizeArg);
87       cstr.bound(value) >= 0;
88     }
89   }
90 };
91 } // namespace
92 
93 void mlir::gpu::registerValueBoundsOpInterfaceExternalModels(
94     DialectRegistry &registry) {
95   registry.addExtension(+[](MLIRContext *ctx, GPUDialect *dialect) {
96 #define REGISTER(X) X::attachInterface<GpuIdOpInterface<X>>(*ctx);
97     REGISTER(ClusterDimOp)
98     REGISTER(ClusterDimBlocksOp)
99     REGISTER(ClusterIdOp)
100     REGISTER(ClusterBlockIdOp)
101     REGISTER(BlockDimOp)
102     REGISTER(BlockIdOp)
103     REGISTER(GridDimOp)
104     REGISTER(ThreadIdOp)
105     REGISTER(LaneIdOp)
106     REGISTER(SubgroupIdOp)
107     REGISTER(GlobalIdOp)
108     REGISTER(NumSubgroupsOp)
109     REGISTER(SubgroupSizeOp)
110 #undef REGISTER
111 
112     LaunchOp::attachInterface<GpuLaunchOpInterface>(*ctx);
113   });
114 }
115