# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector


def run_apply_patterns(f):
    """Build IR around ``f`` inside an ``apply_patterns`` op and print it.

    Creates a fresh module containing a ``transform.SequenceOp`` (failure
    propagation mode ``Propagate``, no extra bindings, body target typed as
    ``transform.AnyOpType``). The sequence body holds a single
    ``transform.ApplyPatternsOp`` on the sequence's body target and is
    terminated by a ``transform.YieldOp``. ``f()`` is called with the
    insertion point set to the ``patterns`` region of that op, so every
    pattern op it constructs lands there.

    Afterwards a ``TEST: <function name>`` header line and the module are
    printed; the FileCheck directives in each test function match against
    this output.

    Intended as a decorator: the test body runs at decoration time and ``f``
    is returned unchanged.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
                with InsertionPoint(apply.patterns):
                    f()
                # The yield terminates the sequence body, not the patterns region.
                transform.YieldOp()
        # The leading newline keeps the header on its own line so the
        # per-test label directive can anchor on it.
        print("\nTEST:", f.__name__)
        print(module)
    return f


@run_apply_patterns
def non_configurable_patterns():
    """Emit every vector pattern op constructed here without arguments."""
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
    vector.ApplyCastAwayVectorLeadingOneDimPatternsOp()
    # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
    vector.ApplyRankReducingSubviewPatternsOp()
    # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
    vector.ApplyTransferPermutationPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_broadcast
    vector.ApplyLowerBroadcastPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_masks
    vector.ApplyLowerMasksPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_masked_transfers
    vector.ApplyLowerMaskedTransfersPatternsOp()
    # CHECK: transform.apply_patterns.vector.materialize_masks
    vector.ApplyMaterializeMasksPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_outerproduct
    vector.ApplyLowerOuterProductPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_gather
    vector.ApplyLowerGatherPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_scan
    vector.ApplyLowerScanPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_shape_cast
    vector.ApplyLowerShapeCastPatternsOp()


@run_apply_patterns
def configurable_patterns():
    """Emit vector pattern ops that take integer/boolean configuration."""
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    vector.ApplyLowerTransferPatternsOp(max_transfer_rank=4)
    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    vector.ApplyTransferToScfPatternsOp(max_transfer_rank=3, full_unroll=True)


@run_apply_patterns
def enum_configurable_patterns():
    """Emit vector pattern ops configured through enum-valued attributes."""
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )