xref: /llvm-project/mlir/test/Dialect/GPU/all-reduce-maxf.mlir (revision 560564f51c626cf89920f13b6cea96684bac5848)
1// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
2
3// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
4// CHECK: gpu.module @kernels {
5gpu.module @kernels {
6
  // Lowering test for a single uniform `gpu.all_reduce maxnumf` applied to the
  // f32 kernel argument. The expectations below describe the IR printed by the
  // mlir-opt invocation at the top of this file: the rewritten kernel gains a
  // 32-element f32 workgroup attribution (VAL_1) used as scratch space.
7  // CHECK-LABEL: gpu.func @kernel(
8  // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, #gpu.address_space<workgroup>>) kernel {
9  gpu.func @kernel(%arg0 : f32) kernel {
10    // CHECK-DAG:   [[VAL_2:%.*]] = arith.constant 31 : i32
11    // CHECK-DAG:   [[VAL_3:%.*]] = arith.constant 0 : i32
12    // CHECK-DAG:   [[VAL_4:%.*]] = arith.constant 0 : index
13    // CHECK-DAG:   [[VAL_5:%.*]] = arith.constant 32 : i32
14    // CHECK-DAG:   [[VAL_6:%.*]] = arith.constant 1 : i32
15    // CHECK-DAG:   [[VAL_7:%.*]] = arith.constant 2 : i32
16    // CHECK-DAG:   [[VAL_8:%.*]] = arith.constant 4 : i32
17    // CHECK-DAG:   [[VAL_9:%.*]] = arith.constant 8 : i32
18    // CHECK-DAG:   [[VAL_10:%.*]] = arith.constant 16 : i32
19    // CHECK:   [[VAL_11:%.*]] = gpu.block_dim x
20    // CHECK:   [[VAL_12:%.*]] = arith.index_cast [[VAL_11]] : index to i32
21    // CHECK:   [[VAL_13:%.*]] = gpu.block_dim y
22    // CHECK:   [[VAL_14:%.*]] = arith.index_cast [[VAL_13]] : index to i32
23    // CHECK:   [[VAL_15:%.*]] = gpu.block_dim z
24    // CHECK:   [[VAL_16:%.*]] = arith.index_cast [[VAL_15]] : index to i32
25    // CHECK:   [[VAL_17:%.*]] = gpu.thread_id x
26    // CHECK:   [[VAL_18:%.*]] = arith.index_cast [[VAL_17]] : index to i32
27    // CHECK:   [[VAL_19:%.*]] = gpu.thread_id y
28    // CHECK:   [[VAL_20:%.*]] = arith.index_cast [[VAL_19]] : index to i32
29    // CHECK:   [[VAL_21:%.*]] = gpu.thread_id z
30    // CHECK:   [[VAL_22:%.*]] = arith.index_cast [[VAL_21]] : index to i32
    // Linearized thread index: VAL_27 = (tid.z * dim.y + tid.y) * dim.x + tid.x.
    // Total thread count:      VAL_28 = dim.x * dim.y * dim.z.
31    // CHECK:   [[VAL_23:%.*]] = arith.muli [[VAL_22]], [[VAL_14]] : i32
32    // CHECK:   [[VAL_24:%.*]] = arith.addi [[VAL_23]], [[VAL_20]] : i32
33    // CHECK:   [[VAL_25:%.*]] = arith.muli [[VAL_24]], [[VAL_12]] : i32
34    // CHECK:   [[VAL_26:%.*]] = arith.muli [[VAL_12]], [[VAL_14]] : i32
35    // CHECK:   [[VAL_27:%.*]] = arith.addi [[VAL_25]], [[VAL_18]] : i32
36    // CHECK:   [[VAL_28:%.*]] = arith.muli [[VAL_26]], [[VAL_16]] : i32
    // VAL_29 = index & 31 (position inside a group of 32 threads) and
    // VAL_30 = (VAL_29 == 0). VAL_32 = total - (index - VAL_29) is the number of
    // threads from the start of this group to the end of the block. When fewer
    // than 32 remain, ^bb1 is taken and every shuffle's validity result is
    // branched on; otherwise ^bb17 is taken.
37    // CHECK:   [[VAL_29:%.*]] = arith.andi [[VAL_27]], [[VAL_2]] : i32
38    // CHECK:   [[VAL_30:%.*]] = arith.cmpi eq, [[VAL_29]], [[VAL_3]] : i32
39    // CHECK:   [[VAL_31:%.*]] = arith.subi [[VAL_27]], [[VAL_29]] : i32
40    // CHECK:   [[VAL_32:%.*]] = arith.subi [[VAL_28]], [[VAL_31]] : i32
41    // CHECK:   [[VAL_33:%.*]] = arith.cmpi slt, [[VAL_32]], [[VAL_5]] : i32
42    // CHECK:   cf.cond_br [[VAL_33]], ^bb1, ^bb17
43    // CHECK: ^bb1:
44    // CHECK:   [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_32]] : f32
45    // CHECK:   cf.cond_br [[VAL_35]], ^bb2, ^bb3
46    // CHECK: ^bb2:
47    // CHECK:   [[VAL_36:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_34]] : f32
48    // CHECK:   cf.br ^bb4([[VAL_36]] : f32)
49    // CHECK: ^bb3:
50    // CHECK:   cf.br ^bb4([[VAL_0]] : f32)
51    // CHECK: ^bb4([[VAL_38:%.*]]: f32):
52    // CHECK:   [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle xor [[VAL_38]], [[VAL_7]], [[VAL_32]] : f32
53    // CHECK:   cf.cond_br [[VAL_40]], ^bb5, ^bb6
54    // CHECK: ^bb5:
55    // CHECK:   [[VAL_41:%.*]] = arith.maxnumf [[VAL_38]], [[VAL_39]] : f32
56    // CHECK:   cf.br ^bb7([[VAL_41]] : f32)
57    // CHECK: ^bb6:
58    // CHECK:   cf.br ^bb7([[VAL_38]] : f32)
59    // CHECK: ^bb7([[VAL_43:%.*]]: f32):
60    // CHECK:   [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle xor [[VAL_43]], [[VAL_8]], [[VAL_32]] : f32
61    // CHECK:   cf.cond_br [[VAL_45]], ^bb8, ^bb9
62    // CHECK: ^bb8:
63    // CHECK:   [[VAL_46:%.*]] = arith.maxnumf [[VAL_43]], [[VAL_44]] : f32
64    // CHECK:   cf.br ^bb10([[VAL_46]] : f32)
65    // CHECK: ^bb9:
66    // CHECK:   cf.br ^bb10([[VAL_43]] : f32)
67    // CHECK: ^bb10([[VAL_48:%.*]]: f32):
68    // CHECK:   [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle xor [[VAL_48]], [[VAL_9]], [[VAL_32]] : f32
69    // CHECK:   cf.cond_br [[VAL_50]], ^bb11, ^bb12
70    // CHECK: ^bb11:
71    // CHECK:   [[VAL_51:%.*]] = arith.maxnumf [[VAL_48]], [[VAL_49]] : f32
72    // CHECK:   cf.br ^bb13([[VAL_51]] : f32)
73    // CHECK: ^bb12:
74    // CHECK:   cf.br ^bb13([[VAL_48]] : f32)
75    // CHECK: ^bb13([[VAL_53:%.*]]: f32):
76    // CHECK:   [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle xor [[VAL_53]], [[VAL_10]], [[VAL_32]] : f32
77    // CHECK:   cf.cond_br [[VAL_55]], ^bb14, ^bb15
78    // CHECK: ^bb14:
79    // CHECK:   [[VAL_56:%.*]] = arith.maxnumf [[VAL_53]], [[VAL_54]] : f32
80    // CHECK:   cf.br ^bb16([[VAL_56]] : f32)
81    // CHECK: ^bb15:
82    // CHECK:   cf.br ^bb16([[VAL_53]] : f32)
83    // CHECK: ^bb16([[VAL_58:%.*]]: f32):
84    // CHECK:   cf.br ^bb18([[VAL_58]] : f32)
    // Full group of 32: five xor-shuffle + maxnumf steps with offsets
    // 1, 2, 4, 8, 16 and width 32; the validity results are not branched on.
85    // CHECK: ^bb17:
86    // CHECK:   [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_5]] : f32
87    // CHECK:   [[VAL_62:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_59]] : f32
88    // CHECK:   [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle xor [[VAL_62]], [[VAL_7]], [[VAL_5]] : f32
89    // CHECK:   [[VAL_66:%.*]] = arith.maxnumf [[VAL_62]], [[VAL_63]] : f32
90    // CHECK:   [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle xor [[VAL_66]], [[VAL_8]], [[VAL_5]] : f32
91    // CHECK:   [[VAL_70:%.*]] = arith.maxnumf [[VAL_66]], [[VAL_67]] : f32
92    // CHECK:   [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle xor [[VAL_70]], [[VAL_9]], [[VAL_5]] : f32
93    // CHECK:   [[VAL_74:%.*]] = arith.maxnumf [[VAL_70]], [[VAL_71]] : f32
94    // CHECK:   [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle xor [[VAL_74]], [[VAL_10]], [[VAL_5]] : f32
95    // CHECK:   [[VAL_78:%.*]] = arith.maxnumf [[VAL_74]], [[VAL_75]] : f32
96    // CHECK:   cf.br ^bb18([[VAL_78]] : f32)
    // Threads with VAL_30 set (masked index == 0) store the group's partial
    // result into the workgroup buffer at element index / 32.
97    // CHECK: ^bb18([[VAL_79:%.*]]: f32):
98    // CHECK:   cf.cond_br [[VAL_30]], ^bb19, ^bb20
99    // CHECK: ^bb19:
100    // CHECK:   [[VAL_80:%.*]] = arith.divsi [[VAL_27]], [[VAL_5]] : i32
101    // CHECK:   [[VAL_81:%.*]] = arith.index_cast [[VAL_80]] : i32 to index
102    // CHECK:   store [[VAL_79]], [[VAL_1]]{{\[}}[[VAL_81]]] : memref<32xf32, #gpu.address_space<workgroup>>
103    // CHECK:   cf.br ^bb21
104    // CHECK: ^bb20:
105    // CHECK:   cf.br ^bb21
    // After the barrier, VAL_83 = (total + 31) / 32 is the number of partial
    // results. Threads whose index is below VAL_83 load one partial each and
    // repeat the same shuffle reduction, using VAL_83 as the shuffle width when
    // it is below 32 (^bb23) and 32 otherwise (^bb39).
106    // CHECK: ^bb21:
107    // CHECK:   gpu.barrier
108    // CHECK:   [[VAL_82:%.*]] = arith.addi [[VAL_28]], [[VAL_2]] : i32
109    // CHECK:   [[VAL_83:%.*]] = arith.divsi [[VAL_82]], [[VAL_5]] : i32
110    // CHECK:   [[VAL_84:%.*]] = arith.cmpi slt, [[VAL_27]], [[VAL_83]] : i32
111    // CHECK:   cf.cond_br [[VAL_84]], ^bb22, ^bb41
112    // CHECK: ^bb22:
113    // CHECK:   [[VAL_85:%.*]] = arith.index_cast [[VAL_27]] : i32 to index
114    // CHECK:   [[VAL_86:%.*]] = memref.load [[VAL_1]]{{\[}}[[VAL_85]]] : memref<32xf32, #gpu.address_space<workgroup>>
115    // CHECK:   [[VAL_87:%.*]] = arith.cmpi slt, [[VAL_83]], [[VAL_5]] : i32
116    // CHECK:   cf.cond_br [[VAL_87]], ^bb23, ^bb39
117    // CHECK: ^bb23:
118    // CHECK:   [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_83]] : f32
119    // CHECK:   cf.cond_br [[VAL_89]], ^bb24, ^bb25
120    // CHECK: ^bb24:
121    // CHECK:   [[VAL_91:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_88]] : f32
122    // CHECK:   cf.br ^bb26([[VAL_91]] : f32)
123    // CHECK: ^bb25:
124    // CHECK:   cf.br ^bb26([[VAL_86]] : f32)
125    // CHECK: ^bb26([[VAL_92:%.*]]: f32):
126    // CHECK:   [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle xor [[VAL_92]], [[VAL_7]], [[VAL_83]] : f32
127    // CHECK:   cf.cond_br [[VAL_94]], ^bb27, ^bb28
128    // CHECK: ^bb27:
129    // CHECK:   [[VAL_96:%.*]] = arith.maxnumf [[VAL_92]], [[VAL_93]] : f32
130    // CHECK:   cf.br ^bb29([[VAL_96]] : f32)
131    // CHECK: ^bb28:
132    // CHECK:   cf.br ^bb29([[VAL_92]] : f32)
133    // CHECK: ^bb29([[VAL_97:%.*]]: f32):
134    // CHECK:   [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle xor [[VAL_97]], [[VAL_8]], [[VAL_83]] : f32
135    // CHECK:   cf.cond_br [[VAL_99]], ^bb30, ^bb31
136    // CHECK: ^bb30:
137    // CHECK:   [[VAL_101:%.*]] = arith.maxnumf [[VAL_97]], [[VAL_98]] : f32
138    // CHECK:   cf.br ^bb32([[VAL_101]] : f32)
139    // CHECK: ^bb31:
140    // CHECK:   cf.br ^bb32([[VAL_97]] : f32)
141    // CHECK: ^bb32([[VAL_102:%.*]]: f32):
142    // CHECK:   [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle xor [[VAL_102]], [[VAL_9]], [[VAL_83]] : f32
143    // CHECK:   cf.cond_br [[VAL_104]], ^bb33, ^bb34
144    // CHECK: ^bb33:
145    // CHECK:   [[VAL_106:%.*]] = arith.maxnumf [[VAL_102]], [[VAL_103]] : f32
146    // CHECK:   cf.br ^bb35([[VAL_106]] : f32)
147    // CHECK: ^bb34:
148    // CHECK:   cf.br ^bb35([[VAL_102]] : f32)
149    // CHECK: ^bb35([[VAL_107:%.*]]: f32):
150    // CHECK:   [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle xor [[VAL_107]], [[VAL_10]], [[VAL_83]] : f32
151    // CHECK:   cf.cond_br [[VAL_109]], ^bb36, ^bb37
152    // CHECK: ^bb36:
153    // CHECK:   [[VAL_111:%.*]] = arith.maxnumf [[VAL_107]], [[VAL_108]] : f32
154    // CHECK:   cf.br ^bb38([[VAL_111]] : f32)
155    // CHECK: ^bb37:
156    // CHECK:   cf.br ^bb38([[VAL_107]] : f32)
157    // CHECK: ^bb38([[VAL_112:%.*]]: f32):
158    // CHECK:   cf.br ^bb40([[VAL_112]] : f32)
159    // CHECK: ^bb39:
160    // CHECK:   [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_5]] : f32
161    // CHECK:   [[VAL_116:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_113]] : f32
162    // CHECK:   [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle xor [[VAL_116]], [[VAL_7]], [[VAL_5]] : f32
163    // CHECK:   [[VAL_120:%.*]] = arith.maxnumf [[VAL_116]], [[VAL_117]] : f32
164    // CHECK:   [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle xor [[VAL_120]], [[VAL_8]], [[VAL_5]] : f32
165    // CHECK:   [[VAL_124:%.*]] = arith.maxnumf [[VAL_120]], [[VAL_121]] : f32
166    // CHECK:   [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle xor [[VAL_124]], [[VAL_9]], [[VAL_5]] : f32
167    // CHECK:   [[VAL_128:%.*]] = arith.maxnumf [[VAL_124]], [[VAL_125]] : f32
168    // CHECK:   [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle xor [[VAL_128]], [[VAL_10]], [[VAL_5]] : f32
169    // CHECK:   [[VAL_132:%.*]] = arith.maxnumf [[VAL_128]], [[VAL_129]] : f32
170    // CHECK:   cf.br ^bb40([[VAL_132]] : f32)
    // The second-stage result is written to element 0 (VAL_4) of the workgroup
    // buffer, followed by a final barrier.
171    // CHECK: ^bb40([[VAL_133:%.*]]: f32):
172    // CHECK:   store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, #gpu.address_space<workgroup>>
173    // CHECK:   cf.br ^bb42
174    // CHECK: ^bb41:
175    // CHECK:   cf.br ^bb42
176    // CHECK: ^bb42:
177    // CHECK:   gpu.barrier
    // Input under test: the op that the expectations above are matched against
    // after rewriting.
178    %sum = gpu.all_reduce maxnumf %arg0 uniform {} : (f32) -> (f32)
179    gpu.return
180  }
181
182}
183