1from ..lang import * 2 3T1 = TV.T1 4T2 = TV.T2 5 6Batch = S.Batch 7 8 9@linalg_structured_op 10def copy( 11 I=TensorDef(T1), 12 O=TensorDef(U, output=True), 13 cast=TypeFnAttrDef(default=TypeFn.cast_signed), 14): 15 """Copies the tensor elementwise. 16 17 Numeric casting is performed on the input operand, promoting it to the same 18 data type as the accumulator/output. 19 """ 20 defines(Canonicalizer) 21 O[None] = cast(U, I[None]) 22 23 24@linalg_structured_op 25def elemwise_unary( 26 I=TensorDef(T1), 27 O=TensorDef(U, output=True), 28 fun=UnaryFnAttrDef(default=UnaryFn.exp), 29 cast=TypeFnAttrDef(default=TypeFn.cast_signed), 30): 31 """Applies the unary function fun elementwise. 32 33 Numeric casting is performed on the input operand, promoting it to the same 34 data type as the accumulator/output. 35 """ 36 O[None] = fun(cast(U, I[None])) 37 38 39@linalg_structured_op 40def exp( 41 I=TensorDef(T1), 42 O=TensorDef(T1, output=True), 43): 44 """Applies exp(x) elementwise. 45 46 No numeric casting is performed on the input operand. 47 """ 48 O[None] = UnaryFn.exp(I[None]) 49 50 51@linalg_structured_op 52def log( 53 I=TensorDef(T1), 54 O=TensorDef(T1, output=True), 55): 56 """Applies log(x) elementwise. 57 58 No numeric casting is performed on the input operand. 59 """ 60 O[None] = UnaryFn.log(I[None]) 61 62 63@linalg_structured_op 64def abs( 65 I=TensorDef(T1), 66 O=TensorDef(T1, output=True), 67): 68 """Applies abs(x) elementwise. 69 70 No numeric casting is performed on the input operand. 71 """ 72 O[None] = UnaryFn.abs(I[None]) 73 74 75@linalg_structured_op 76def ceil( 77 I=TensorDef(T1), 78 O=TensorDef(T1, output=True), 79): 80 """Applies ceil(x) elementwise. 81 82 No numeric casting is performed on the input operand. 83 """ 84 O[None] = UnaryFn.ceil(I[None]) 85 86 87@linalg_structured_op 88def floor( 89 I=TensorDef(T1), 90 O=TensorDef(T1, output=True), 91): 92 """Applies floor(x) elementwise. 93 94 No numeric casting is performed on the input operand. 
95 """ 96 O[None] = UnaryFn.floor(I[None]) 97 98 99@linalg_structured_op(op_class_name="NegFOp") 100def negf( 101 I=TensorDef(T1), 102 O=TensorDef(T1, output=True), 103): 104 """Applies negf(x) elementwise. 105 106 No numeric casting is performed on the input operand. 107 """ 108 O[None] = UnaryFn.negf(I[None]) 109 110 111@linalg_structured_op(op_class_name="ReciprocalOp") 112def reciprocal( 113 I=TensorDef(T1), 114 O=TensorDef(T1, output=True), 115): 116 """Applies reciprocal(x) elementwise. 117 118 No numeric casting is performed on the input operand. 119 """ 120 O[None] = UnaryFn.reciprocal(I[None]) 121 122 123@linalg_structured_op 124def round( 125 I=TensorDef(T1), 126 O=TensorDef(T1, output=True), 127): 128 """Applies round(x) elementwise. 129 130 No numeric casting is performed on the input operand. 131 """ 132 O[None] = UnaryFn.round(I[None]) 133 134 135@linalg_structured_op 136def sqrt( 137 I=TensorDef(T1), 138 O=TensorDef(T1, output=True), 139): 140 """Applies sqrt(x) elementwise. 141 142 No numeric casting is performed on the input operand. 143 """ 144 O[None] = UnaryFn.sqrt(I[None]) 145 146 147@linalg_structured_op 148def rsqrt( 149 I=TensorDef(T1), 150 O=TensorDef(T1, output=True), 151): 152 """Applies rsqrt(x) elementwise. 153 154 No numeric casting is performed on the input operand. 155 """ 156 O[None] = UnaryFn.rsqrt(I[None]) 157 158 159@linalg_structured_op 160def square( 161 I=TensorDef(T1), 162 O=TensorDef(T1, output=True), 163): 164 """Applies square(x) elementwise. 165 166 No numeric casting is performed on the input operand. 167 """ 168 O[None] = UnaryFn.square(I[None]) 169 170 171@linalg_structured_op 172def tanh( 173 I=TensorDef(T1), 174 O=TensorDef(T1, output=True), 175): 176 """Applies tanh(x) elementwise. 177 178 No numeric casting is performed on the input operand. 
179 """ 180 O[None] = UnaryFn.tanh(I[None]) 181 182 183@linalg_structured_op 184def erf( 185 I=TensorDef(T1), 186 O=TensorDef(T1, output=True), 187): 188 """Applies erf(x) elementwise. 189 190 No numeric casting is performed on the input operand. 191 """ 192 O[None] = UnaryFn.erf(I[None]) 193 194 195@linalg_structured_op 196def elemwise_binary( 197 lhs=TensorDef(T1), 198 rhs=TensorDef(T2), 199 O=TensorDef(U, output=True), 200 fun=BinaryFnAttrDef(default=BinaryFn.add), 201 cast=TypeFnAttrDef(default=TypeFn.cast_signed), 202): 203 """Applies the binary function fun elementwise. 204 205 Numeric casting is performed on the input operand, promoting it to the same 206 data type as the accumulator/output. 207 """ 208 O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) 209 210 211@linalg_structured_op 212def add( 213 lhs=TensorDef(T1), 214 rhs=TensorDef(T1), 215 O=TensorDef(T1, output=True), 216): 217 """Adds two tensors elementwise. 218 219 The shapes and element types must be identical. The appropriate casts, 220 broadcasts and reductions should be done previously to calling this op. 221 222 This means reduction/broadcast/element cast semantics is explicit. Further 223 passes can take that into account when lowering this code. For example, 224 a `linalg.broadcast` + `linalg.add` sequence can be lowered to a 225 `linalg.generic` with different affine maps for the two operands. 226 """ 227 O[None] = BinaryFn.add(lhs[None], rhs[None]) 228 229 230@linalg_structured_op 231def sub( 232 lhs=TensorDef(T1), 233 rhs=TensorDef(T1), 234 O=TensorDef(T1, output=True), 235): 236 """Subtracts two tensors elementwise. 237 238 The shapes and element types must be identical. The appropriate casts, 239 broadcasts and reductions should be done previously to calling this op. 240 241 This means reduction/broadcast/element cast semantics is explicit. Further 242 passes can take that into account when lowering this code. 
For example, 243 a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a 244 `linalg.generic` with different affine maps for the two operands. 245 """ 246 O[None] = BinaryFn.sub(lhs[None], rhs[None]) 247 248 249@linalg_structured_op 250def mul( 251 lhs=TensorDef(T1), 252 rhs=TensorDef(T1), 253 O=TensorDef(T1, output=True), 254): 255 """Multiplies two tensors elementwise. 256 257 The shapes and element types must be identical. The appropriate casts, 258 broadcasts and reductions should be done previously to calling this op. 259 260 This means reduction/broadcast/element cast semantics is explicit. Further 261 passes can take that into account when lowering this code. For example, 262 a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a 263 `linalg.generic` with different affine maps for the two operands. 264 """ 265 O[None] = BinaryFn.mul(lhs[None], rhs[None]) 266 267 268@linalg_structured_op 269def div( 270 lhs=TensorDef(T1), 271 rhs=TensorDef(T1), 272 O=TensorDef(T1, output=True), 273): 274 """Divides the first tensor by the second tensor, elementwise. 275 276 The shapes and element types must be identical. The appropriate casts, 277 broadcasts and reductions should be done previously to calling this op. 278 279 This means reduction/broadcast/element cast semantics is explicit. Further 280 passes can take that into account when lowering this code. For example, 281 a `linalg.broadcast` + `linalg.div` sequence can be lowered to a 282 `linalg.generic` with different affine maps for the two operands. 283 """ 284 O[None] = BinaryFn.div(lhs[None], rhs[None]) 285 286 287@linalg_structured_op 288def div_unsigned( 289 lhs=TensorDef(T1), 290 rhs=TensorDef(T1), 291 O=TensorDef(T1, output=True), 292): 293 """Divides the first tensor by the second tensor, elementwise. For integer 294 types, performs an unsigned division. 295 296 The shapes and element types must be identical. 
The appropriate casts, 297 broadcasts and reductions should be done previously to calling this op. 298 299 This means reduction/broadcast/element cast semantics is explicit. Further 300 passes can take that into account when lowering this code. For example, 301 a `linalg.broadcast` + `linalg.div` sequence can be lowered to a 302 `linalg.generic` with different affine maps for the two operands. 303 """ 304 O[None] = BinaryFn.div_unsigned(lhs[None], rhs[None]) 305 306 307@linalg_structured_op 308def max( 309 lhs=TensorDef(T1), 310 rhs=TensorDef(T1), 311 O=TensorDef(T1, output=True), 312): 313 """Takes the max (signed) between two inputs, elementwise. 314 315 The shapes and element types must be identical. The appropriate casts, 316 broadcasts and reductions should be done previously to calling this op. 317 318 This means reduction/broadcast/element cast semantics is explicit. Further 319 passes can take that into account when lowering this code. For example, 320 a `linalg.broadcast` + `linalg.max` sequence can be lowered to a 321 `linalg.generic` with different affine maps for the two operands. 322 """ 323 O[None] = BinaryFn.max_signed(lhs[None], rhs[None]) 324 325 326@linalg_structured_op 327def min( 328 lhs=TensorDef(T1), 329 rhs=TensorDef(T1), 330 O=TensorDef(T1, output=True), 331): 332 """Takes the min (signed) between two inputs, elementwise. 333 334 The shapes and element types must be identical. The appropriate casts, 335 broadcasts and reductions should be done previously to calling this op. 336 337 This means reduction/broadcast/element cast semantics is explicit. Further 338 passes can take that into account when lowering this code. For example, 339 a `linalg.broadcast` + `linalg.min` sequence can be lowered to a 340 `linalg.generic` with different affine maps for the two operands. 
341 """ 342 O[None] = BinaryFn.min_signed(lhs[None], rhs[None]) 343 344 345@linalg_structured_op(op_class_name="PowFOp") 346def powf( 347 lhs=TensorDef(T1), 348 rhs=TensorDef(T1), 349 O=TensorDef(T1, output=True), 350): 351 """Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`. 352 353 Only applies to floating point values. 354 355 The shapes and element types must be identical. The appropriate casts, 356 broadcasts and reductions should be done previously to calling this op. 357 358 This means reduction/broadcast/element cast semantics is explicit. Further 359 passes can take that into account when lowering this code. For example, 360 a `linalg.broadcast` + `linalg.powf` sequence can be lowered to a 361 `linalg.generic` with different affine maps for the two operands. 362 """ 363 O[None] = BinaryFn.powf(lhs[None], rhs[None]) 364 365 366@linalg_structured_op 367def select( 368 cond=TensorDef(U), 369 lhs=TensorDef(T1), 370 rhs=TensorDef(T1), 371 O=TensorDef(T1, output=True), 372): 373 """Chooses one value based on a binary condition supplied as its first operand. 374 375 The shapes and element types must be identical. The appropriate casts, 376 broadcasts and reductions should be done previously to calling this op. 377 378 This means reduction/broadcast/element cast semantics is explicit. Further 379 passes can take that into account when lowering this code. For example, 380 a `linalg.broadcast` + `linalg.select` sequence can be lowered to a 381 `linalg.generic` with different affine maps for the two operands. 382 """ 383 O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) 384 385 386@linalg_structured_op 387def quantized_matmul( 388 A=TensorDef(T1, S.M, S.K), 389 B=TensorDef(T2, S.K, S.N), 390 AZp=ScalarDef(I32), 391 BZp=ScalarDef(I32), 392 C=TensorDef(U, S.M, S.N, output=True), 393): 394 """Performs a matrix multiplication of two 2D inputs. 
395 396 Numeric casting is performed on the operands to the inner multiply, promoting 397 them to the same data type as the accumulator/output. The quantized variant 398 includes zero-point adjustments for the left and right operands of the 399 matmul. 400 """ 401 domain(D.m, D.n, D.k) 402 C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( 403 TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp) 404 ) 405 406 407@linalg_structured_op 408def matmul_transpose_a( 409 A=TensorDef(T1, S.K, S.N), 410 B=TensorDef(T2, S.K, S.M), 411 C=TensorDef(U, S.M, S.N, output=True), 412 cast=TypeFnAttrDef(default=TypeFn.cast_signed), 413): 414 """Performs a matrix multiplication of two 2D inputs with lhs operand 415 transposed. 416 417 Numeric casting is performed on the operands to the inner multiply, promoting 418 them to the same data type as the accumulator/output. 419 """ 420 domain(D.m, D.n, D.k) 421 implements(ContractionOpInterface) 422 C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) 423 424 425@linalg_structured_op 426def matmul_transpose_b( 427 A=TensorDef(T1, S.M, S.K), 428 B=TensorDef(T2, S.N, S.K), 429 C=TensorDef(U, S.M, S.N, output=True), 430 cast=TypeFnAttrDef(default=TypeFn.cast_signed), 431): 432 """Performs a matrix multiplication of two 2D inputs with rhs operand 433 transposed. 434 435 Numeric casting is performed on the operands to the inner multiply, promoting 436 them to the same data type as the accumulator/output. 437 """ 438 domain(D.m, D.n, D.k) 439 implements(ContractionOpInterface) 440 C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) 441 442 443@linalg_structured_op 444def mmt4d( 445 lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), 446 rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), 447 accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True), 448): 449 """Performs a matrix-matrix-transpose multiplication of two 4D inputs. 
450 451 Differences from linalg.matmul: 452 * The right hand side is transposed, whence the 't' in 'mmt'. 453 * The input and output tensors have a 4D shape instead of a 2D shape. They 454 are interpreted as 2D matrices with one level of 2D tile subdivision, 455 whence the 2+2=4 dimensions. The inner tile dimensions are identified with 456 '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads 457 as: MxK tiles, each of shape M0xK0. 458 """ 459 domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) 460 implements(ContractionOpInterface) 461 accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( 462 TV.AccumType, lhs[D.m, D.k, D.m0, D.k0] 463 ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) 464 465 466@linalg_structured_op 467def batch_mmt4d( 468 lhs=TensorDef(TV.LhsType, Batch, S.M, S.K, S.M0, S.K0), 469 rhs=TensorDef(TV.RhsType, Batch, S.N, S.K, S.N0, S.K0), 470 accum=TensorDef(TV.AccumType, Batch, S.M, S.N, S.M0, S.N0, output=True), 471): 472 """Performs a batched matrix-matrix-transpose multiplication of two 473 batched-4D (5D) inputs. 474 475 Besides the outermost batch dimension has the same semantic as 476 linalg.batch_matmul, the differences from linalg.batch_matmul in the 477 non-batch dimensions are the same as linalg.mmt4d vs. linalg.matmul. See the 478 description of lingalg.mmt4d. 479 """ 480 domain(D.b, D.m, D.n, D.k, D.m0, D.n0, D.k0) 481 implements(ContractionOpInterface) 482 accum[D.b, D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( 483 TV.AccumType, lhs[D.b, D.m, D.k, D.m0, D.k0] 484 ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0]) 485 486 487@linalg_structured_op 488def batch_matmul( 489 A=TensorDef(T1, Batch, S.M, S.K), 490 B=TensorDef(T2, Batch, S.K, S.N), 491 C=TensorDef(U, Batch, S.M, S.N, output=True), 492): 493 """Performs a batched matrix multiplication of two 3D inputs. 
494 495 Numeric casting is performed on the operands to the inner multiply, promoting 496 them to the same data type as the accumulator/output. 497 """ 498 domain(D.b, D.m, D.n, D.k) 499 implements(ContractionOpInterface) 500 C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( 501 U, B[D.b, D.k, D.n] 502 ) 503 504 505@linalg_structured_op 506def batch_matmul_transpose_a( 507 A=TensorDef(T1, Batch, S.K, S.M), 508 B=TensorDef(T2, Batch, S.K, S.N), 509 C=TensorDef(U, Batch, S.M, S.N, output=True), 510): 511 """Performs a batched matrix multiplication of two 3D inputs where lhs operand 512 has its non-batch dimensions transposed. 513 514 Numeric casting is performed on the operands to the inner multiply, promoting 515 them to the same data type as the accumulator/output. 516 """ 517 domain(D.b, D.m, D.n, D.k) 518 implements(ContractionOpInterface) 519 C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed( 520 U, B[D.b, D.k, D.n] 521 ) 522 523 524@linalg_structured_op 525def batch_matmul_transpose_b( 526 A=TensorDef(T1, Batch, S.M, S.K), 527 B=TensorDef(T2, Batch, S.N, S.K), 528 C=TensorDef(U, Batch, S.M, S.N, output=True), 529): 530 """Performs a batched matrix multiplication of two 3D inputs where rhs operand 531 has its non-batch dimensions transposed. 532 533 Numeric casting is performed on the operands to the inner multiply, promoting 534 them to the same data type as the accumulator/output. 535 """ 536 domain(D.b, D.m, D.n, D.k) 537 implements(ContractionOpInterface) 538 C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( 539 U, B[D.b, D.n, D.k] 540 ) 541 542 543@linalg_structured_op 544def quantized_batch_matmul( 545 A=TensorDef(T1, Batch, S.M, S.K), 546 B=TensorDef(T2, Batch, S.K, S.N), 547 AZp=ScalarDef(I32), 548 BZp=ScalarDef(I32), 549 C=TensorDef(U, Batch, S.M, S.N, output=True), 550): 551 """Performs a batched matrix multiplication of two 3D inputs. 
552 553 Numeric casting is performed on the operands to the inner multiply, promoting 554 them to the same data type as the accumulator/output. The quantized variant 555 includes zero-point adjustments for the left and right operands of the 556 matmul. 557 """ 558 domain(D.b, D.m, D.n, D.k) 559 C[D.b, D.m, D.n] += ( 560 TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp) 561 ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) 562 563 564@linalg_structured_op 565def batch_reduce_matmul( 566 A=TensorDef(T1, Batch, S.M, S.K), 567 B=TensorDef(T2, Batch, S.K, S.N), 568 C=TensorDef(U, S.M, S.N, output=True), 569): 570 """Performs a batch-reduce matrix multiplication of two 3D inputs. 571 The partial multiplication results are reduced into a 2D output. 572 573 Numeric casting is performed on the operands to the inner multiply, promoting 574 them to the same data type as the accumulator/output. 575 """ 576 domain(D.b, D.m, D.n, D.k) 577 implements(ContractionOpInterface) 578 C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( 579 U, B[D.b, D.k, D.n] 580 ) 581 582 583@linalg_structured_op 584def matvec( 585 A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True) 586): 587 """Performs a matrix-vector multiplication. 588 589 Numeric casting is performed on the operands to the inner multiply, promoting 590 them to the same data type as the accumulator/output. 591 """ 592 domain(D.m, D.n) 593 implements(ContractionOpInterface) 594 x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) 595 596 597@linalg_structured_op 598def vecmat( 599 y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True) 600): 601 """Performs a vector-matrix multiplication. 602 603 Numeric casting is performed on the operands to the inner multiply, promoting 604 them to the same data type as the accumulator/output. 
605 """ 606 domain(D.n, D.m) 607 implements(ContractionOpInterface) 608 x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) 609 610 611@linalg_structured_op 612def batch_matvec( 613 A=TensorDef(T1, Batch, S.M, S.K), 614 B=TensorDef(T2, Batch, S.K), 615 C=TensorDef(U, Batch, S.M, output=True), 616): 617 """Performs a batched matrix-vector multiplication. 618 619 Numeric casting is performed on the operands to the inner multiply, promoting 620 them to the same data type as the accumulator/output. 621 """ 622 domain(D.b, D.m, D.k) 623 implements(ContractionOpInterface) 624 C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( 625 U, B[D.b, D.k] 626 ) 627 628 629@linalg_structured_op 630def batch_vecmat( 631 A=TensorDef(T1, Batch, S.K), 632 B=TensorDef(T2, Batch, S.K, S.N), 633 C=TensorDef(U, Batch, S.N, output=True), 634): 635 """Performs a batched matrix-vector multiplication. 636 637 Numeric casting is performed on the operands to the inner multiply, promoting 638 them to the same data type as the accumulator/output. 639 """ 640 domain(D.b, D.n, D.k) 641 implements(ContractionOpInterface) 642 C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed( 643 U, B[D.b, D.k, D.n] 644 ) 645 646 647@linalg_structured_op 648def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): 649 """Performs a dot product of two vectors to a scalar result. 650 651 Numeric casting is performed on the operands to the inner multiply, promoting 652 them to the same data type as the accumulator/output. 653 """ 654 implements(ContractionOpInterface) 655 C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) 656 657 658@linalg_structured_op 659def conv_1d( 660 I=TensorDef(T1, S.OW + S.KW), 661 K=TensorDef(T2, S.KW), 662 O=TensorDef(U, S.OW, output=True), 663): 664 """Performs 1-D convolution with no channels. 
665 666 Numeric casting is performed on the operands to the inner multiply, promoting 667 them to the same data type as the accumulator/output. 668 """ 669 implements(ConvolutionOpInterface) 670 domain(D.ow, D.kw) 671 O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw]) 672 673 674@linalg_structured_op 675def conv_2d( 676 I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), 677 K=TensorDef(T2, S.KH, S.KW), 678 O=TensorDef(U, S.OH, S.OW, output=True), 679): 680 """Performs 2-D convolution with no channels. 681 682 Numeric casting is performed on the operands to the inner multiply, promoting 683 them to the same data type as the accumulator/output. 684 """ 685 implements(ConvolutionOpInterface) 686 domain(D.oh, D.ow, D.kh, D.kw) 687 O[D.oh, D.ow] += TypeFn.cast_signed( 688 U, I[D.oh + D.kh, D.ow + D.kw] 689 ) * TypeFn.cast_signed(U, K[D.kh, D.kw]) 690 691 692@linalg_structured_op 693def conv_3d( 694 I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), 695 K=TensorDef(T2, S.KD, S.KH, S.KW), 696 O=TensorDef(U, S.OD, S.OH, S.OW, output=True), 697): 698 """Performs 3-D convolution with no channels. 699 700 Numeric casting is performed on the operands to the inner multiply, promoting 701 them to the same data type as the accumulator/output. 702 """ 703 implements(ConvolutionOpInterface) 704 domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) 705 O[D.od, D.oh, D.ow] += TypeFn.cast_signed( 706 U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw] 707 ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw]) 708 709 710@linalg_structured_op 711def conv_1d_nwc_wcf( 712 I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), 713 K=TensorDef(T2, S.KW, S.C, S.F), 714 O=TensorDef(U, S.N, S.OW, S.F, output=True), 715 strides=IndexAttrDef(S.SW, default=[1]), 716 dilations=IndexAttrDef(S.DW, default=[1]), 717): 718 """Performs 1-D convolution. 719 720 Numeric casting is performed on the operands to the inner multiply, promoting 721 them to the same data type as the accumulator/output. 
722 """ 723 implements(ConvolutionOpInterface) 724 domain(D.n, D.ow, D.f, D.kw, D.c) 725 O[D.n, D.ow, D.f] += TypeFn.cast_signed( 726 U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c] 727 ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) 728 729 730@linalg_structured_op 731def conv_1d_ncw_fcw( 732 I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), 733 K=TensorDef(T2, S.F, S.C, S.KW), 734 O=TensorDef(U, S.N, S.F, S.OW, output=True), 735 strides=IndexAttrDef(S.SW, default=[1]), 736 dilations=IndexAttrDef(S.DW, default=[1]), 737): 738 """Performs 1-D convolution. 739 740 Layout: 741 * Input: NCW. 742 * Kernel: FCW. 743 744 Numeric casting is performed on the operands to the inner multiply, promoting 745 them to the same data type as the accumulator/output. 746 """ 747 implements(ConvolutionOpInterface) 748 domain(D.n, D.f, D.ow, D.c, D.kw) 749 O[D.n, D.f, D.ow] += TypeFn.cast_signed( 750 U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW] 751 ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw]) 752 753 754@linalg_structured_op 755def conv_2d_nhwc_hwcf( 756 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), 757 K=TensorDef(T2, S.KH, S.KW, S.C, S.F), 758 O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), 759 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 760 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 761): 762 """Performs 2-D convolution. 763 764 Layout: 765 * Input: NHWC. 766 * Kernel: HWCF. 767 768 Numeric casting is performed on the operands to the inner multiply, promoting 769 them to the same data type as the accumulator/output. 
770 """ 771 implements(ConvolutionOpInterface) 772 domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) 773 O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( 774 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] 775 ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) 776 777 778@linalg_structured_op 779def conv_2d_nhwc_fhwc( 780 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), 781 K=TensorDef(T2, S.F, S.KH, S.KW, S.C), 782 O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), 783 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 784 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 785): 786 """Performs 2-D convolution. 787 788 Layout: 789 * Input: NHWC. 790 * Kernel: FHWC. 791 792 Numeric casting is performed on the operands to the inner multiply, promoting 793 them to the same data type as the accumulator/output. 794 """ 795 implements(ConvolutionOpInterface) 796 domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) 797 O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( 798 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] 799 ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) 800 801 802@linalg_structured_op 803def conv_2d_nhwc_hwcf_q( 804 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), 805 K=TensorDef(T2, S.KH, S.KW, S.C, S.F), 806 IZp=ScalarDef(I32), 807 KZp=ScalarDef(I32), 808 O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), 809 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 810 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 811): 812 """Performs 2-D convolution with zero point offsets. 813 814 Layout: 815 * Input: NHWC. 816 * Kernel: HWCF. 817 818 Numeric casting is performed on the operands to the inner multiply, promoting 819 them to the same data type as the accumulator/output. This includes the zero 820 point offsets common to quantized operations. 
821 """ 822 implements(ConvolutionOpInterface) 823 domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) 824 O[D.n, D.oh, D.ow, D.f] += ( 825 TypeFn.cast_signed( 826 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] 827 ) 828 - TypeFn.cast_signed(U, IZp) 829 ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) 830 831 832@linalg_structured_op 833def conv_2d_nhwc_fhwc_q( 834 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), 835 K=TensorDef(T2, S.F, S.KH, S.KW, S.C), 836 IZp=ScalarDef(I32), 837 KZp=ScalarDef(I32), 838 O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), 839 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 840 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 841): 842 """Performs 2-D convolution with zero point offsets. 843 844 Layout: 845 * Input: NHWC. 846 * Kernel: FHWC. 847 848 Numeric casting is performed on the operands to the inner multiply, promoting 849 them to the same data type as the accumulator/output. This includes the zero 850 point offsets common to quantized operations. 851 """ 852 implements(ConvolutionOpInterface) 853 domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) 854 O[D.n, D.oh, D.ow, D.f] += ( 855 TypeFn.cast_signed( 856 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] 857 ) 858 - TypeFn.cast_signed(U, IZp) 859 ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp)) 860 861 862@linalg_structured_op 863def conv_2d_nchw_fchw_q( 864 I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), 865 K=TensorDef(T2, S.F, S.C, S.KH, S.KW), 866 IZp=ScalarDef(I32), 867 KZp=ScalarDef(I32), 868 O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), 869 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 870 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 871): 872 """Performs 2-D convolution with zero point offsets. 873 874 Layout: 875 * Input: NCHW. 876 * Kernel: FCHW. 
877 878 Numeric casting is performed on the operands to the inner multiply, promoting 879 them to the same data type as the accumulator/output. This includes the zero 880 point offsets common to quantized operations. 881 """ 882 implements(ConvolutionOpInterface) 883 domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) 884 O[D.n, D.f, D.oh, D.ow] += ( 885 TypeFn.cast_signed( 886 U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] 887 ) 888 - TypeFn.cast_signed(U, IZp) 889 ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp)) 890 891@linalg_structured_op 892def conv_2d_nchw_fchw( 893 I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), 894 K=TensorDef(T2, S.F, S.C, S.KH, S.KW), 895 O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), 896 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 897 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 898): 899 """Performs 2-D convolution. 900 901 Layout: 902 * Input: NCHW. 903 * Kernel: FCHW. 904 905 Numeric casting is performed on the operands to the inner multiply, promoting 906 them to the same data type as the accumulator/output. 907 """ 908 implements(ConvolutionOpInterface) 909 domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) 910 O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( 911 U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] 912 ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) 913 914 915@linalg_structured_op 916def conv_2d_ngchw_fgchw( 917 I=TensorDef( 918 T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW 919 ), 920 K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), 921 O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), 922 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 923 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 924): 925 """Performs 2-D grouped convolution. 926 927 Layout: 928 * Input: NGCHW. 929 * Kernel: FGCHW. 

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    # Strided/dilated window access: each spatial input index is
    # output index * stride + kernel index * dilation.
    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])


@linalg_structured_op
def conv_2d_ngchw_gfchw(
    I=TensorDef(
        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
    ),
    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution.

    Layout:
      * Input: NGCHW.
      * Kernel: GFCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])


@linalg_structured_op
def conv_2d_nhwgc_gfhwc(
    I=TensorDef(
        T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
    ),
    K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
    O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution.

    Layout:
      * Input: NHWGC.
      * Kernel: GFHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])


@linalg_structured_op
def conv_2d_nhwgc_gfhwc_q(
    I=TensorDef(
        T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
    ),
    K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution with zero point offsets.

    Layout:
      * Input: NHWGC.
      * Kernel: GFHWC.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
    # The zero points are subtracted from the promoted operands before the
    # multiply-accumulate.
    O[D.n, D.oh, D.ow, D.g, D.fg] += (
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (
        TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
        - TypeFn.cast_signed(U, KZp)
    )


@linalg_structured_op
def conv_2d_ngchw_gfchw_q(
    I=TensorDef(
        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
    ),
    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D grouped convolution with zero-point offsets.

    Layout:
      * Input: NGCHW.
      * Kernel: GFCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.g, D.fg, D.oh, D.ow] += (
        TypeFn.cast_signed(
            U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (
        TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
        - TypeFn.cast_signed(U, KZp)
    )


@linalg_structured_op
def conv_3d_ndhwc_dhwcf(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
            D.c,
        ],
    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])


@linalg_structured_op
def conv_3d_ndhwc_dhwcf_q(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3-D convolution with zero point offsets.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    # The zero points are subtracted from the promoted operands before the
    # multiply-accumulate.
    O[D.n, D.od, D.oh, D.ow, D.f] += (
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.od * S.SD + D.kd * S.DD,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
                D.c,
            ],
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (
        TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
        - TypeFn.cast_signed(U, KZp)
    )


@linalg_structured_op
def conv_3d_ncdhw_fcdhw(
    I=TensorDef(
        T1,
        S.N,
        S.C,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
    ),
    K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
    O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.c,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
        ],
    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])


@linalg_structured_op
def depthwise_conv_1d_nwc_wc(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs depth-wise 1-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.ic, D.kw)
    # Channel multiplier is 1: input and output share the channel dim D.ic.
    O[D.n, D.ow, D.ic] += TypeFn.cast_signed(
        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kw, D.ic])


@linalg_structured_op
def depthwise_conv_1d_ncw_cw(
    I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.IC, S.KW),
    O=TensorDef(U, S.N, S.IC, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs depth-wise 1-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.ic, D.kw)
    O[D.n, D.ic, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.ic, D.kw])


@linalg_structured_op
def depthwise_conv_1d_nwc_wcm(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KW, S.IC, S.CM),
    O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs depth-wise 1-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.ic, D.cm, D.kw)
    # D.cm is the channel-multiplier dim: each input channel yields CM outputs.
    O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])


@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])


@linalg_structured_op
def depthwise_conv_2d_nchw_chw(
    I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.IC, S.KH, S.KW),
    O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. Multiplier is set to 1
    which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])


@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
1292 """ 1293 implements(ConvolutionOpInterface) 1294 domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) 1295 O[D.n, D.oh, D.ow, D.ic] += ( 1296 TypeFn.cast_signed( 1297 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] 1298 ) 1299 - TypeFn.cast_signed(U, IZp) 1300 ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp)) 1301 1302 1303@linalg_structured_op 1304def depthwise_conv_2d_nhwc_hwcm( 1305 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), 1306 K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), 1307 O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), 1308 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 1309 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 1310): 1311 """Performs depth-wise 2-D convolution. 1312 1313 Numeric casting is performed on the operands to the inner multiply, promoting 1314 them to the same data type as the accumulator/output. 1315 """ 1316 implements(ConvolutionOpInterface) 1317 domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) 1318 O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( 1319 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] 1320 ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) 1321 1322 1323@linalg_structured_op 1324def depthwise_conv_2d_nhwc_hwcm_q( 1325 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), 1326 K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), 1327 IZp=ScalarDef(I32), 1328 KZp=ScalarDef(I32), 1329 O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), 1330 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 1331 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 1332): 1333 """Performs depth-wise 2-D convolution. 1334 1335 Numeric casting is performed on the operands to the inner multiply, promoting 1336 them to the same data type as the accumulator/output. 
1337 """ 1338 implements(ConvolutionOpInterface) 1339 domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) 1340 O[D.n, D.oh, D.ow, D.ic, D.cm] += ( 1341 TypeFn.cast_signed( 1342 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] 1343 ) 1344 - TypeFn.cast_signed(U, IZp) 1345 ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp)) 1346 1347 1348@linalg_structured_op 1349def depthwise_conv_3d_ndhwc_dhwc( 1350 I=TensorDef( 1351 T1, 1352 S.N, 1353 S.OD * S.SD + S.KD * S.DD, 1354 S.OH * S.SH + S.KH * S.DH, 1355 S.OW * S.SW + S.KW * S.DW, 1356 S.IC, 1357 ), 1358 K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), 1359 O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True), 1360 strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), 1361 dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), 1362): 1363 """Performs depth-wise 3-D convolution. 1364 1365 Numeric casting is performed on the operands to the inner multiply, promoting 1366 them to the same data type as the accumulator/output. Multiplier is set to 1 1367 which is a special case for most depthwise convolutions. 1368 """ 1369 implements(ConvolutionOpInterface) 1370 domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) 1371 O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( 1372 U, 1373 I[ 1374 D.n, 1375 D.od * S.SD + D.kd * S.DD, 1376 D.oh * S.SH + D.kh * S.DH, 1377 D.ow * S.SW + D.kw * S.DW, 1378 D.ic, 1379 ], 1380 ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic]) 1381 1382 1383@linalg_structured_op 1384def depthwise_conv_3d_ncdhw_cdhw( 1385 I=TensorDef( 1386 T1, 1387 S.N, 1388 S.IC, 1389 S.OD * S.SD + S.KD * S.DD, 1390 S.OH * S.SH + S.KH * S.DH, 1391 S.OW * S.SW + S.KW * S.DW, 1392 ), 1393 K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), 1394 O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True), 1395 strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), 1396 dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), 1397): 1398 """Performs depth-wise 3-D convolution. 
1399 1400 Numeric casting is performed on the operands to the inner multiply, promoting 1401 them to the same data type as the accumulator/output. Multiplier is set to 1 1402 which is a special case for most depthwise convolutions. 1403 """ 1404 implements(ConvolutionOpInterface) 1405 domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) 1406 O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( 1407 U, 1408 I[ 1409 D.n, 1410 D.ic, 1411 D.od * S.SD + D.kd * S.DD, 1412 D.oh * S.SH + D.kh * S.DH, 1413 D.ow * S.SW + D.kw * S.DW, 1414 ], 1415 ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw]) 1416 1417 1418@linalg_structured_op 1419def depthwise_conv_3d_ndhwc_dhwcm( 1420 I=TensorDef( 1421 T1, 1422 S.N, 1423 S.OD * S.SD + S.KD * S.DD, 1424 S.OH * S.SH + S.KH * S.DH, 1425 S.OW * S.SW + S.KW * S.DW, 1426 S.IC, 1427 ), 1428 K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), 1429 O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True), 1430 strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), 1431 dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), 1432): 1433 """Performs depth-wise 3-D convolution. 1434 1435 Numeric casting is performed on the operands to the inner multiply, promoting 1436 them to the same data type as the accumulator/output. 
1437 """ 1438 implements(ConvolutionOpInterface) 1439 domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) 1440 O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( 1441 U, 1442 I[ 1443 D.n, 1444 D.od * S.SD + D.kd * S.DD, 1445 D.oh * S.SH + D.kh * S.DH, 1446 D.ow * S.SW + D.kw * S.DW, 1447 D.ic, 1448 ], 1449 ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) 1450 1451 1452@linalg_structured_op 1453def pooling_nhwc_sum( 1454 I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), 1455 K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), 1456 O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), 1457 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 1458 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 1459): 1460 """Performs sum pooling. 1461 1462 Layout: 1463 * Input: NHWC. 1464 * Kernel: HW. 1465 1466 Numeric casting is performed on the input operand, promoting it to the same 1467 data type as the accumulator/output. 1468 """ 1469 implements(ConvolutionOpInterface) 1470 domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) 1471 O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( 1472 U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] 1473 ) 1474 1475 1476@linalg_structured_op 1477def pooling_nchw_sum( 1478 I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), 1479 K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), 1480 O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), 1481 strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), 1482 dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), 1483): 1484 """Performs sum pooling. 1485 1486 Layout: 1487 * Input: NCHW. 1488 * Kernel: HW. 1489 1490 Numeric casting is performed on the input operand, promoting it to the same 1491 data type as the accumulator/output. 
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
    O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
    )


@linalg_structured_op
def pooling_nhwc_max(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    # Signed-max reduction over the kernel window dims D.kh x D.kw.
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )


@linalg_structured_op
def pooling_nhwc_max_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs unsigned max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
        TypeFn.cast_unsigned(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )


@linalg_structured_op
def pooling_nchw_max(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
    O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.c,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
            ],
        )
    )


@linalg_structured_op
def pooling_nhwc_min(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs min pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )


@linalg_structured_op
def pooling_nhwc_min_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs unsigned min pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
        TypeFn.cast_unsigned(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
        )
    )


@linalg_structured_op
def pooling_nwc_sum(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs sum pooling.

    Layout:
      * Input: NWC.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.c, D.kw)
    # K is shape-only (index_dims): it contributes D.kw but is never read.
    O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])


@linalg_structured_op
def pooling_ncw_sum(
    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.C, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs sum pooling.

    Layout:
      * Input: NCW.
      * Kernel: W.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.c, D.ow, D.kw)
    O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])


@linalg_structured_op
def pooling_nwc_max(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.c, D.kw)
    O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](
        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )


@linalg_structured_op
def pooling_nwc_max_unsigned(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs unsigned max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.c, D.kw)
    O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]](
        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )


@linalg_structured_op
def pooling_ncw_max(
    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.C, S.OW, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.c, D.ow, D.kw)
    O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.c,
                D.ow * S.SW + D.kw * S.DW,
            ],
        )
    )


@linalg_structured_op
def pooling_nwc_min(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs min pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.c, D.kw)
    O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](
        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )


@linalg_structured_op
def pooling_nwc_min_unsigned(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
    O=TensorDef(U, S.N, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
    """Performs unsigned min pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.c, D.kw)
    O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]](
        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
    )


@linalg_structured_op
def pooling_ndhwc_sum(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3D sum pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
        U,
        I[
            D.n,
            D.od * S.SD + D.kd * S.DD,
            D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW,
            D.c,
        ],
    )


@linalg_structured_op
def pooling_ndhwc_max(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3D max pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    # Signed-max reduction over the 3-D kernel window D.kd x D.kh x D.kw.
    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.od * S.SD + D.kd * S.DD,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
                D.c,
            ],
        )
    )


@linalg_structured_op
def pooling_ndhwc_min(
    I=TensorDef(
        T1,
        S.N,
        S.OD * S.SD + S.KD * S.DD,
        S.OH * S.SH + S.KH * S.DH,
        S.OW * S.SW + S.KW * S.DW,
        S.C,
    ),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
    """Performs 3D min pooling.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
        TypeFn.cast_signed(
            U,
            I[
                D.n,
                D.od * S.SD + D.kd * S.DD,
                D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,
                D.c,
            ],
        )
    )


@linalg_structured_op
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
    """Fills the output tensor with the given value.

    Works for arbitrary ranked output tensors since the operation performs scalar
    accesses only and is thus rank polymorphic. Numeric casting is performed on
    the value operand, promoting it to the same data type as the output.
    """
    implements(FillOpInterface)
    defines(Canonicalizer)
    # O[None] is a rank-polymorphic scalar access: every element receives value.
    O[None] = TypeFn.cast_signed(U, value)


@linalg_structured_op
def fill_rng_2d(
    min=ScalarDef(F64),
    max=ScalarDef(F64),
    seed=ScalarDef(I32),
    O=TensorDef(T, S.M, S.N, output=True),
):
    """Fills the output tensor with pseudo random numbers.

    The operation generates pseudo random numbers using a linear congruential
    generator. It provides no guarantees regarding the distribution of the
    generated random numbers. Instead of generating the random numbers
    sequentially, it instantiates one random number generator per data element
    and runs them in parallel. The seed operand and the indices of the data
    element seed the random number generation. The min and max operands limit
    the range of the generated random numbers.
    """
    domain(D.m, D.n)
    # Linear congruential generator (LCG) multiplier and increment constants.
    multiplier = TypeFn.cast_signed(I32, const(1103515245))
    increment = TypeFn.cast_signed(I32, const(12345))
    # One generator per element: two chained LCG steps mix `seed` with the
    # element's row index (D.m) and then its column index (D.n).
    rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
    rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
    # Scale the i32 into the [min, max] interval: 2.3283064e-10 ~= 2^-32, and
    # the 2147483647 (2^31 - 1) offset shifts the signed value to non-negative.
    inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
    offset = TypeFn.cast_signed(F64, const(2147483647))
    scaling = (max - min) * inv_range
    O[D.m, D.n] = TypeFn.cast_signed(
        T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min
    )