# xref: /llvm-project/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (revision 39358f846d1e336def88ff9c25581fab392d59fe)
# Pull in the OpDSL vocabulary (TensorDef, ScalarDef, TypeFn, UnaryFn,
# BinaryFn, TernaryFn, S/D/TV symbol namespaces, linalg_structured_op, ...).
from ..lang import *

# Short aliases for the two independent input type variables used by
# mixed-type ops below.
T1 = TV.T1
T2 = TV.T2

# Symbolic batch dimension shared by all batched op definitions.
Batch = S.Batch
8
@linalg_structured_op
def copy(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Copies the tensor elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Attach canonicalization patterns to the generated op.
    defines(Canonicalizer)
    # [None] denotes rank-polymorphic identity indexing on both operands;
    # `cast` is an op attribute (default: signed cast) applied per element.
    O[None] = cast(U, I[None])
22
23
@linalg_structured_op
def elemwise_unary(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Applies the unary function fun elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # `fun` and `cast` are op attributes resolved when the op is built; the
    # cast promotes the input element to the output element type first.
    O[None] = fun(cast(U, I[None]))
37
38
@linalg_structured_op
def exp(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies exp(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.exp(I[None])
49
50
@linalg_structured_op
def log(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies log(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.log(I[None])
61
62
@linalg_structured_op
def abs(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies abs(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # NOTE: the DSL op name intentionally shadows the builtin `abs`.
    O[None] = UnaryFn.abs(I[None])
73
74
@linalg_structured_op
def ceil(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies ceil(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.ceil(I[None])
85
86
@linalg_structured_op
def floor(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies floor(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.floor(I[None])
97
98
# op_class_name overrides the default camel-cased name for the generated op.
@linalg_structured_op(op_class_name="NegFOp")
def negf(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies negf(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    O[None] = UnaryFn.negf(I[None])
109
110
# op_class_name overrides the default camel-cased name for the generated op.
@linalg_structured_op(op_class_name="ReciprocalOp")
def reciprocal(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies reciprocal(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    O[None] = UnaryFn.reciprocal(I[None])
121
122
@linalg_structured_op
def round(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies round(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # NOTE: the DSL op name intentionally shadows the builtin `round`.
    O[None] = UnaryFn.round(I[None])
133
134
@linalg_structured_op
def sqrt(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies sqrt(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.sqrt(I[None])
145
146
@linalg_structured_op
def rsqrt(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies rsqrt(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.rsqrt(I[None])
157
158
@linalg_structured_op
def square(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies square(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.square(I[None])
169
170
@linalg_structured_op
def tanh(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies tanh(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.tanh(I[None])
181
182
@linalg_structured_op
def erf(
    I=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Applies erf(x) elementwise.

    No numeric casting is performed on the input operand.
    """
    # Rank-polymorphic identity maps; input and output share element type T1.
    O[None] = UnaryFn.erf(I[None])
193
194
@linalg_structured_op
def elemwise_binary(
    lhs=TensorDef(T1),
    rhs=TensorDef(T2),
    O=TensorDef(U, output=True),
    fun=BinaryFnAttrDef(default=BinaryFn.add),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Applies the binary function fun elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Inputs may have distinct element types (T1, T2); both are cast to the
    # output type U before `fun` (an op attribute) is applied.
    O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
209
210
@linalg_structured_op
def add(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Adds two tensors elementwise.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.add` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # All operands share T1: no implicit cast, no broadcast.
    O[None] = BinaryFn.add(lhs[None], rhs[None])
228
229
@linalg_structured_op
def sub(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Subtracts two tensors elementwise.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # All operands share T1: no implicit cast, no broadcast.
    O[None] = BinaryFn.sub(lhs[None], rhs[None])
247
248
@linalg_structured_op
def mul(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Multiplies two tensors elementwise.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # All operands share T1: no implicit cast, no broadcast.
    O[None] = BinaryFn.mul(lhs[None], rhs[None])
266
267
@linalg_structured_op
def div(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Divides the first tensor by the second tensor, elementwise.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # Signed division for integer element types; see div_unsigned otherwise.
    O[None] = BinaryFn.div(lhs[None], rhs[None])
285
286
@linalg_structured_op
def div_unsigned(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Divides the first tensor by the second tensor, elementwise. For integer
    types, performs an unsigned division.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # Unsigned counterpart of `div` above.
    O[None] = BinaryFn.div_unsigned(lhs[None], rhs[None])
305
306
@linalg_structured_op
def max(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Takes the max (signed) between two inputs, elementwise.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # Signed comparison semantics for integer element types.
    # NOTE: the DSL op name intentionally shadows the builtin `max`.
    O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
324
325
@linalg_structured_op
def min(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Takes the min (signed) between two inputs, elementwise.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # Signed comparison semantics for integer element types.
    # NOTE: the DSL op name intentionally shadows the builtin `min`.
    O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
343
344
# op_class_name overrides the default camel-cased name for the generated op.
@linalg_structured_op(op_class_name="PowFOp")
def powf(
    lhs=TensorDef(T1),
    rhs=TensorDef(T1),
    O=TensorDef(T1, output=True),
):
    """Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.

    Only applies to floating point values.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.powf` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    O[None] = BinaryFn.powf(lhs[None], rhs[None])
364
365
@linalg_structured_op
def select(
    cond=TensorDef(U),
    lhs=TensorDef(T1),
    rhs=TensorDef(T2),
    O=TensorDef(T1, output=True),
):
    """Chooses one value based on a binary condition supplied as its first operand.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # Per-element: O = cond ? lhs : rhs.
    O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
384
385
@linalg_structured_op
def quantized_matmul(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    AZp=ScalarDef(I32),
    BZp=ScalarDef(I32),
    C=TensorDef(U, S.M, S.N, output=True),
):
    """Performs a matrix multiplication of two 2D inputs.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. The quantized variant
    includes zero-point adjustments for the left and right operands of the
    matmul.
    """
    # k is the reduction dimension; zero points are subtracted (after casting
    # to the accumulator type U) before the multiply-accumulate.
    domain(D.m, D.n, D.k)
    C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
        TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)
    )
405
406
@linalg_structured_op
def matmul_transpose_a(
    A=TensorDef(T1, S.K, S.N),
    B=TensorDef(T2, S.K, S.M),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Performs a matrix multiplication of two 2D inputs with lhs operand
    transposed.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    # A is indexed [k, m]: the lhs is consumed in transposed layout.
    C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
423
424
@linalg_structured_op
def matmul_transpose_b(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.N, S.K),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Performs a matrix multiplication of two 2D inputs with rhs operand
    transposed.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    # B is indexed [n, k]: the rhs is consumed in transposed layout.
    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
441
442
@linalg_structured_op
def mmt4d(
    lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
    rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
    accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True),
):
    """Performs a matrix-matrix-transpose multiplication of two 4D inputs.

    Differences from linalg.matmul:
    * The right hand side is transposed, whence the 't' in 'mmt'.
    * The input and output tensors have a 4D shape instead of a 2D shape. They
      are interpreted as 2D matrices with one level of 2D tile subdivision,
      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
      as: MxK tiles, each of shape M0xK0.
    """
    # Outer (m, n, k) and inner-tile (m0, n0, k0) dimensions; k and k0 reduce.
    domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
    implements(ContractionOpInterface)
    accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
        TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]
    ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
464
465
@linalg_structured_op
def batch_mmt4d(
    lhs=TensorDef(TV.LhsType, Batch, S.M, S.K, S.M0, S.K0),
    rhs=TensorDef(TV.RhsType, Batch, S.N, S.K, S.N0, S.K0),
    accum=TensorDef(TV.AccumType, Batch, S.M, S.N, S.M0, S.N0, output=True),
):
    """Performs a batched matrix-matrix-transpose multiplication of two
    batched-4D (5D) inputs.

    The outermost batch dimension has the same semantics as in
    linalg.batch_matmul; the differences from linalg.batch_matmul in the
    non-batch dimensions are the same as linalg.mmt4d vs. linalg.matmul. See the
    description of linalg.mmt4d.
    """
    # Parallel batch dimension b in front of the mmt4d iteration space.
    domain(D.b, D.m, D.n, D.k, D.m0, D.n0, D.k0)
    implements(ContractionOpInterface)
    accum[D.b, D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
        TV.AccumType, lhs[D.b, D.m, D.k, D.m0, D.k0]
    ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
485
486
@linalg_structured_op
def batch_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, Batch, S.M, S.N, output=True),
):
    """Performs a batched matrix multiplication of two 3D inputs.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # b is a parallel batch dimension; k reduces.
    domain(D.b, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k, D.n]
    )
503
504
@linalg_structured_op
def batch_matmul_transpose_a(
    A=TensorDef(T1, Batch, S.K, S.M),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, Batch, S.M, S.N, output=True),
):
    """Performs a batched matrix multiplication of two 3D inputs where lhs operand
    has its non-batch dimensions transposed.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    domain(D.b, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    # A is indexed [b, k, m]: transposed lhs per batch.
    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
        U, B[D.b, D.k, D.n]
    )
522
523
@linalg_structured_op
def batch_matmul_transpose_b(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.N, S.K),
    C=TensorDef(U, Batch, S.M, S.N, output=True),
):
    """Performs a batched matrix multiplication of two 3D inputs where rhs operand
    has its non-batch dimensions transposed.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    domain(D.b, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    # B is indexed [b, n, k]: transposed rhs per batch.
    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.n, D.k]
    )
541
542
@linalg_structured_op
def quantized_batch_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    AZp=ScalarDef(I32),
    BZp=ScalarDef(I32),
    C=TensorDef(U, Batch, S.M, S.N, output=True),
):
    """Performs a batched matrix multiplication of two 3D inputs.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. The quantized variant
    includes zero-point adjustments for the left and right operands of the
    matmul.
    """
    # Zero points are subtracted (after casting to U) before multiplying.
    domain(D.b, D.m, D.n, D.k)
    C[D.b, D.m, D.n] += (
        TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)
    ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
562
563
@linalg_structured_op
def batch_reduce_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
):
    """Performs a batch-reduce matrix multiplication of two 3D inputs.
    The partial multiplication results are reduced into a 2D output.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Unlike batch_matmul, b does not appear on C: both b and k reduce.
    domain(D.b, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k, D.n]
    )
581
582
@linalg_structured_op
def matvec(
    A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
):
    """Performs a matrix-vector multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # NOTE: y is the input vector and x the output — names follow the op spec.
    domain(D.m, D.n)
    implements(ContractionOpInterface)
    x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
595
596
@linalg_structured_op
def vecmat(
    y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True)
):
    """Performs a vector-matrix multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # x[n] = sum_m y[m] * A[m, n]; m is the reduction dimension.
    domain(D.n, D.m)
    implements(ContractionOpInterface)
    x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
609
610
@linalg_structured_op
def batch_matvec(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K),
    C=TensorDef(U, Batch, S.M, output=True),
):
    """Performs a batched matrix-vector multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # C[b, m] = sum_k A[b, m, k] * B[b, k].
    domain(D.b, D.m, D.k)
    implements(ContractionOpInterface)
    C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k]
    )
627
628
@linalg_structured_op
def batch_vecmat(
    A=TensorDef(T1, Batch, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, Batch, S.N, output=True),
):
    """Performs a batched vector-matrix multiplication.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # C[b, n] = sum_k A[b, k] * B[b, k, n] (vector on the left).
    domain(D.b, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k, D.n]
    )
645
646
@linalg_structured_op
def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
    """Performs a dot product of two vectors to a scalar result.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # C is 0-D ([None]); the single dimension m reduces.
    implements(ContractionOpInterface)
    C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
656
657
@linalg_structured_op
def conv_1d(
    I=TensorDef(T1, S.OW + S.KW),
    K=TensorDef(T2, S.KW),
    O=TensorDef(U, S.OW, output=True),
):
    """Performs 1-D convolution with no channels.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Output position ow reads the input window at ow + kw; kw reduces.
    domain(D.ow, D.kw)
    O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw])
672
673
@linalg_structured_op
def conv_2d(
    I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
    K=TensorDef(T2, S.KH, S.KW),
    O=TensorDef(U, S.OH, S.OW, output=True),
):
    """Performs 2-D convolution with no channels.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Output (oh, ow) reads the input window at (oh + kh, ow + kw).
    domain(D.oh, D.ow, D.kh, D.kw)
    O[D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.oh + D.kh, D.ow + D.kw]
    ) * TypeFn.cast_signed(U, K[D.kh, D.kw])
690
691
@linalg_structured_op
def conv_3d(
    I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
    K=TensorDef(T2, S.KD, S.KH, S.KW),
    O=TensorDef(U, S.OD, S.OH, S.OW, output=True),
):
    """Performs 3-D convolution with no channels.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Output (od, oh, ow) reads the input window at (od + kd, oh + kh, ow + kw).
    domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
    O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]
    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw])
708
709
@linalg_structured_op
def conv_1d_nwc_wcf(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs 1-D convolution.

    Layout:
      * Input: NWC.
      * Kernel: WCF.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Input width index is ow * stride + kw * dilation; kw and c reduce.
    domain(D.n, D.ow, D.f, D.kw, D.c)
    O[D.n, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]
    ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
728
729
@linalg_structured_op
def conv_1d_ncw_fcw(
    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.F, S.C, S.KW),
    O=TensorDef(U, S.N, S.F, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs 1-D convolution.

    Layout:
      * Input: NCW.
      * Kernel: FCW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Input width index is ow * stride + kw * dilation; c and kw reduce.
    domain(D.n, D.f, D.ow, D.c, D.kw)
    O[D.n, D.f, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw])
752
753
@linalg_structured_op
def conv_2d_nhwc_hwcf(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: HWCF.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Spatial input indices are o* * stride + k* * dilation; kh, kw, c reduce.
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
776
777
@linalg_structured_op
def conv_2d_nhwc_fhwc(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: FHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Same iteration space as conv_2d_nhwc_hwcf; only the kernel layout differs.
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
    ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
800
801
@linalg_structured_op
def conv_2d_nhwc_hwcf_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution with zero point offsets.

    Layout:
      * Input: NHWC.
      * Kernel: HWCF.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    # Quantized variant of conv_2d_nhwc_hwcf: zero points IZp/KZp are
    # subtracted (after casting to U) before the multiply-accumulate.
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += (
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
830
831
@linalg_structured_op
def conv_2d_nhwc_fhwc_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution with zero point offsets.

    Layout:
      * Input: NHWC.
      * Kernel: FHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    # Quantized variant of conv_2d_nhwc_fhwc: zero points IZp/KZp are
    # subtracted (after casting to U) before the multiply-accumulate.
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += (
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
860
861
@linalg_structured_op
def conv_2d_nchw_fchw_q(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution with zero point offsets.

    Layout:
      * Input: NCHW.
      * Kernel: FCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    # Quantized variant of conv_2d_nchw_fchw: zero points IZp/KZp are
    # subtracted (after casting to U) before the multiply-accumulate.
    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.f, D.oh, D.ow] += (
        TypeFn.cast_signed(
            U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
890
@linalg_structured_op
def conv_2d_nchw_fchw(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution.

    Layout:
      * Input: NCHW.
      * Kernel: FCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Spatial input indices are o* * stride + k* * dilation; c, kh, kw reduce.
    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
913
914
@linalg_structured_op
def conv_2d_ngchw_fgchw(
    I=TensorDef(
        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
    ),
    K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution.

    Layout:
      * Input: NGCHW.
      * Kernel: FGCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # g is a parallel group dimension shared by input, kernel and output;
    # fg indexes the per-group output channels. c, kh, kw reduce.
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
939
940
@linalg_structured_op
def conv_2d_ngchw_gfchw(
    I=TensorDef(
        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
    ),
    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution.

    Layout:
      * Input: NGCHW.
      * Kernel: GFCHW (group dim leading, unlike conv_2d_ngchw_fgchw).

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, g, fg, oh, ow) index the output; (c, kh, kw) are reduced by the += below.
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
965
966
@linalg_structured_op
def conv_2d_nhwgc_gfhwc(
    I=TensorDef(
        T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
    ),
    K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
    O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution.

    Layout:
      * Input: NHWGC.
      * Kernel: GFHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, oh, ow, g, fg) index the output; (kh, kw, c) are reduced by the += below.
    domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
991
992
@linalg_structured_op
def conv_2d_nhwgc_gfhwc_q(
    I=TensorDef(
        T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
    ),
    K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution with zero point offsets.

    Layout:
      * Input: NHWGC.
      * Kernel: GFHWC.

    IZp and KZp are the i32 input and kernel zero points, subtracted from the
    corresponding operand after casting and before the multiply.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, oh, ow, g, fg) index the output; (kh, kw, c) are reduced by the += below.
    domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.g, D.fg] += (
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (
        TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
        - TypeFn.cast_signed(U, KZp)
    )
1026
1027
@linalg_structured_op
def conv_2d_ngchw_gfchw_q(
    I=TensorDef(
        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
    ),
    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution with zero-point offsets.

    Layout:
      * Input: NGCHW.
      * Kernel: GFCHW.

    IZp and KZp are the i32 input and kernel zero points, subtracted from the
    corresponding operand after casting and before the multiply.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, g, fg, oh, ow) index the output; (c, kh, kw) are reduced by the += below.
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.g, D.fg, D.oh, D.ow] += (
        TypeFn.cast_signed(
            U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (
        TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
        - TypeFn.cast_signed(U, KZp)
    )
1061
1062
@linalg_structured_op
def conv_3d_ndhwc_dhwcf(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3-D convolution.

    Layout:
      * Input: NDHWC.
      * Kernel: DHWCF.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, od, oh, ow, f) index the output; (kd, kh, kw, c) are reduced below.
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
            D.c,
        ],
    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
1095
1096
@linalg_structured_op
def conv_3d_ndhwc_dhwcf_q(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3-D convolution with zero point offsets.

    Layout:
      * Input: NDHWC.
      * Kernel: DHWCF.

    IZp and KZp are the i32 input and kernel zero points, subtracted from the
    corresponding operand after casting and before the multiply.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, od, oh, ow, f) index the output; (kd, kh, kw, c) are reduced below.
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    O[D.n, D.od, D.oh, D.ow, D.f] += (
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.od * S.SD + D.kd * S.DD,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
                D.c,
            ],
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (
        TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
        - TypeFn.cast_signed(U, KZp)
    )
1138
1139
@linalg_structured_op
def conv_3d_ncdhw_fcdhw(
    I=TensorDef(
        T1,
        S.N,
        S.C,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
    ),
    K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
    O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3-D convolution.

    Layout:
      * Input: NCDHW.
      * Kernel: FCDHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (n, od, oh, ow, f) index the output; (kd, kh, kw, c) are reduced below.
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.c,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
        ],
    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
1172
1173
@linalg_structured_op
def depthwise_conv_1d_nwc_wc(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs depth-wise 1-D convolution.

    Layout:
      * Input: NWC.
      * Kernel: WC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); only kw is reduced.
    domain(D.n, D.ow, D.ic, D.kw)
    O[D.n, D.ow, D.ic] += TypeFn.cast_signed(
        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kw, D.ic])
1193
1194
@linalg_structured_op
def depthwise_conv_1d_ncw_cw(
    I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.IC, S.KW),
    O=TensorDef(U, S.N, S.IC, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs depth-wise 1-D convolution.

    Layout:
      * Input: NCW.
      * Kernel: CW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); only kw is reduced.
    domain(D.n, D.ow, D.ic, D.kw)
    O[D.n, D.ic, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.ic, D.kw])
1214
1215
@linalg_structured_op
def depthwise_conv_1d_nwc_wcm(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KW, S.IC, S.CM),
    O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs depth-wise 1-D convolution.

    Layout:
      * Input: NWC.
      * Kernel: WCM (CM is the per-channel multiplier dimension).

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (ic, cm) index the output (depthwise with multiplier); only kw is reduced.
    domain(D.n, D.ow, D.ic, D.cm, D.kw)
    O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
1234
1235
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: HWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); (kh, kw) are reduced.
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
1255
1256
@linalg_structured_op
def depthwise_conv_2d_nchw_chw(
    I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.IC, S.KH, S.KW),
    O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution.

    Layout:
      * Input: NCHW.
      * Kernel: CHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); (kh, kw) are reduced.
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
1276
1277
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution with zero point offsets.

    Layout:
      * Input: NHWC.
      * Kernel: HWC.

    IZp and KZp are the i32 input and kernel zero points, subtracted from the
    corresponding operand after casting and before the multiply.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); (kh, kw) are reduced.
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic] += (
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp))
1301
1302
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwcm(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: HWCM (CM is the per-channel multiplier dimension).

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (ic, cm) index the output (depthwise with multiplier); (kh, kw) are reduced.
    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
1321
1322
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwcm_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution with zero point offsets.

    Layout:
      * Input: NHWC.
      * Kernel: HWCM (CM is the per-channel multiplier dimension).

    IZp and KZp are the i32 input and kernel zero points, subtracted from the
    corresponding operand after casting and before the multiply.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (ic, cm) index the output (depthwise with multiplier); (kh, kw) are reduced.
    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic, D.cm] += (
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp))
1346
1347
@linalg_structured_op
def depthwise_conv_3d_ndhwc_dhwc(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.IC,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
    # FIX: the output is written as O[n, od, oh, ow, ic] below (rank 5), so the
    # declared shape must include S.IC — the previous rank-4 declaration
    # (N, OD, OH, OW) was inconsistent with the access and with the sibling
    # depthwise_conv_3d_ncdhw_cdhw, which declares all five dims.
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs depth-wise 3-D convolution.

    Layout:
      * Input: NDHWC.
      * Kernel: DHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); (kd, kh, kw) are reduced.
    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
    O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
            D.ic,
        ],
    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
1381
1382
@linalg_structured_op
def depthwise_conv_3d_ncdhw_cdhw(
    I=TensorDef(
        T1,
        S.N,
        S.IC,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
    ),
    K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
    O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs depth-wise 3-D convolution.

    Layout:
      * Input: NCDHW.
      * Kernel: CDHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # ic indexes both input and output (depthwise); (kd, kh, kw) are reduced.
    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
    O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.ic,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
        ],
    ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
1416
1417
@linalg_structured_op
def depthwise_conv_3d_ndhwc_dhwcm(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.IC,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
    # FIX: the output is written as O[n, od, oh, ow, ic, cm] below (rank 6), so
    # the declared shape must include S.IC — the previous rank-5 declaration
    # (N, OD, OH, OW, CM) was inconsistent with the access and with the sibling
    # depthwise_conv_2d_nhwc_hwcm, which declares (..., S.IC, S.CM).
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs depth-wise 3-D convolution.

    Layout:
      * Input: NDHWC.
      * Kernel: DHWCM (CM is the per-channel multiplier dimension).

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # (ic, cm) index the output (depthwise with multiplier); (kd, kh, kw) are reduced.
    domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
    O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
            D.ic,
        ],
    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
1450
1451
@linalg_structured_op
def pooling_nhwc_sum(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs sum pooling.

    Layout:
      * Input: NHWC.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    # (kh, kw) do not index O, so += sums the input over the kernel window.
    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
    )
1474
1475
@linalg_structured_op
def pooling_nchw_sum(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs sum pooling.

    Layout:
      * Input: NCHW.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
    # (kh, kw) do not index O, so += sums the input over the kernel window.
    O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    )
1498
1499
@linalg_structured_op
def pooling_nhwc_max(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs max pooling.

    Layout:
      * Input: NHWC.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    # Reduce over the (kh, kw) window, taking the signed maximum.
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )
1520
1521
@linalg_structured_op
def pooling_nhwc_max_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs unsigned max pooling.

    Layout:
      * Input: NHWC.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    # Unsigned variant: unsigned cast plus unsigned max over the (kh, kw) window.
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
        TypeFn.cast_unsigned(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )
1542
1543
@linalg_structured_op
def pooling_nchw_max(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs max pooling.

    Layout:
      * Input: NCHW.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
    # Reduce over the (kh, kw) window, taking the signed maximum.
    O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.c,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
            ],
        )
    )
1570
1571
@linalg_structured_op
def pooling_nhwc_min(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs min pooling.

    Layout:
      * Input: NHWC.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    # Reduce over the (kh, kw) window, taking the signed minimum.
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )
1592
1593
@linalg_structured_op
def pooling_nhwc_min_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs unsigned min pooling.

    Layout:
      * Input: NHWC.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the (kh, kw) loops.
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    # Unsigned variant: unsigned cast plus unsigned min over the (kh, kw) window.
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
        TypeFn.cast_unsigned(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )
1614
1615
@linalg_structured_op
def pooling_nwc_sum(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs sum pooling.

    Layout:
      * Input: NWC.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.ow, D.c, D.kw)
    # kw does not index O, so += sums the input over the kernel window.
    O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
1636
1637
@linalg_structured_op
def pooling_ncw_sum(
    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.C, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs sum pooling.

    Layout:
      * Input: NCW.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.c, D.ow, D.kw)
    # kw does not index O, so += sums the input over the kernel window.
    O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
1658
1659
@linalg_structured_op
def pooling_nwc_max(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs max pooling.

    Layout:
      * Input: NWC.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.ow, D.c, D.kw)
    # Reduce over the kw window, taking the signed maximum.
    O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](
        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )
1678
1679
@linalg_structured_op
def pooling_nwc_max_unsigned(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs unsigned max pooling.

    Layout:
      * Input: NWC.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.ow, D.c, D.kw)
    # Unsigned variant: unsigned cast plus unsigned max over the kw window.
    O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]](
        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )
1698
1699
@linalg_structured_op
def pooling_ncw_max(
    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.C, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs max pooling.

    Layout:
      * Input: NCW.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.c, D.ow, D.kw)
    # Reduce over the kw window, taking the signed maximum.
    O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.c,
                D.ow * S.SW + D.kw * S.DW,
            ],
        )
    )
1725
1726
@linalg_structured_op
def pooling_nwc_min(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs min pooling.

    Layout:
      * Input: NWC.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.ow, D.c, D.kw)
    # Reduce over the kw window, taking the signed minimum.
    O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](
        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )
1745
1746
@linalg_structured_op
def pooling_nwc_min_unsigned(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs unsigned min pooling.

    Layout:
      * Input: NWC.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Declare that the generated op implements ConvolutionOpInterface.
    implements(ConvolutionOpInterface)
    # K only supplies the window shape; index_dims binds it to the kw loop.
    domain(D.n, D.ow, D.c, D.kw)
    # Unsigned variant: unsigned cast plus unsigned min over the kw window.
    O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]](
        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )
1765
1766
@linalg_structured_op
def pooling_ndhwc_sum(
    # Input in NDHWC layout; each spatial extent is sized so that the
    # o* * S* + k* * D* accesses in the body stay in bounds.
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    # Shape-only kernel operand: index_dims binds its sizes KD/KH/KW to the
    # reduction dimensions kd/kh/kw. Its element values are never read.
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    # Output in NDHWC layout.
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    # Per-spatial-dimension (D, H, W) stride and dilation index attributes.
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3D sum pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Mark the op as convolution-shaped so convolution matchers recognize it.
    implements(ConvolutionOpInterface)
    # Iteration space: n, od, oh, ow, c are parallel; kd, kh, kw are the
    # reduction dimensions.
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    # `+=` expresses a sum reduction over the 3D kernel window; each input
    # element is sign-cast to the accumulator type U before accumulating.
    O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
            D.c,
        ],
    )
1799
1800
@linalg_structured_op
def pooling_ndhwc_max(
    # Input in NDHWC layout; each spatial extent is sized so that the
    # o* * S* + k* * D* accesses in the body stay in bounds.
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    # Shape-only kernel operand: index_dims binds its sizes KD/KH/KW to the
    # reduction dimensions kd/kh/kw. Its element values are never read.
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    # Output in NDHWC layout.
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    # Per-spatial-dimension (D, H, W) stride and dilation index attributes.
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3D max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Mark the op as convolution-shaped so convolution matchers recognize it.
    implements(ConvolutionOpInterface)
    # Iteration space: n, od, oh, ow, c are parallel; kd, kh, kw are the
    # reduction dimensions.
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    # Signed max-reduction over the 3D kernel window; each input element is
    # sign-cast to the output element type U before reducing.
    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.od * S.SD + D.kd * S.DD,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
                D.c,
            ],
        )
    )
1835
1836
@linalg_structured_op
def pooling_ndhwc_min(
    # Input in NDHWC layout; each spatial extent is sized so that the
    # o* * S* + k* * D* accesses in the body stay in bounds.
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    # Shape-only kernel operand: index_dims binds its sizes KD/KH/KW to the
    # reduction dimensions kd/kh/kw. Its element values are never read.
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    # Output in NDHWC layout.
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    # Per-spatial-dimension (D, H, W) stride and dilation index attributes.
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3D min pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    # Mark the op as convolution-shaped so convolution matchers recognize it.
    implements(ConvolutionOpInterface)
    # Iteration space: n, od, oh, ow, c are parallel; kd, kh, kw are the
    # reduction dimensions.
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    # Signed min-reduction over the 3D kernel window; each input element is
    # sign-cast to the output element type U before reducing.
    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.od * S.SD + D.kd * S.DD,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
                D.c,
            ],
        )
    )
1871
1872
@linalg_structured_op
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
    """Fills the output tensor with the given value.

    Works for arbitrary ranked output tensors since the operation performs scalar
    accesses only and is thus rank polymorphic. Numeric casting is performed on
    the value operand, promoting it to the same data type as the output.
    """
    # Implements the fill interface and registers custom canonicalization
    # patterns for the generated op.
    implements(FillOpInterface)
    defines(Canonicalizer)
    # O[None] is a rank-polymorphic scalar access: every element of O is set
    # to the sign-cast scalar value, whatever the rank of O.
    O[None] = TypeFn.cast_signed(U, value)
1884
1885
@linalg_structured_op
def fill_rng_2d(
    # Inclusive-ish range bounds for the generated values and the RNG seed.
    min=ScalarDef(F64),
    max=ScalarDef(F64),
    seed=ScalarDef(I32),
    # 2D output tensor; one independent generator per (m, n) element.
    O=TensorDef(T, S.M, S.N, output=True),
):
    """Fills the output tensor with pseudo random numbers.

    The operation generations pseudo random numbers using a linear congruential
    generator. It provides no guarantees regarding the distribution of the
    generated random numbers. Instead of generating the random numbers
    sequentially, it instantiates one random number generator per data element
    and runs them in parallel. The seed operand and the indices of the data
    element seed the random number generation. The min and max operands limit
    the range of the generated random numbers.
    """
    # Fully parallel iteration space over the 2D output.
    domain(D.m, D.n)
    # Classic LCG constants (the same multiplier/increment as ANSI C rand()).
    multiplier = TypeFn.cast_signed(I32, const(1103515245))
    increment = TypeFn.cast_signed(I32, const(12345))
    # Two chained LCG steps, seeded by the element indices (m, n) and the user
    # seed, so each element gets its own deterministic generator state.
    rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
    rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
    # inv_range ~= 2^-32 and offset = 2^31 - 1: offset shifts the signed i32
    # rand2 into (approximately) [0, 2^32), and scaling maps that range onto
    # [min, max) before the final cast to the output element type T.
    inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
    offset = TypeFn.cast_signed(F64, const(2147483647))
    scaling = (max - min) * inv_range
    O[D.m, D.n] = TypeFn.cast_signed(
        T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min
    )
1914