Skip to content

Commit 37efbd2

Browse files
Added 6x16 and 6xlt16 main kernels for f32 using AVX512 instructions (amd#38)
* Implemented 6xlt8 AVX2 kernel for n<8 inputs * Implemented fringe kernels for 6x16 and 6xlt16 AVX512 kernels for FP32 * Implemented m-fringe kernels for 6xlt8 kernel for AVX2 * Implemented m-fringe kernels for 6xlt8 kernel for AVX2 * Added the deleted kernels and fixed bias bug AMD-Internal: SWLCSG-3556
1 parent 14e46ad commit 37efbd2

11 files changed

Lines changed: 12518 additions & 443 deletions

addon/aocl_gemm/kernels/lpgemm_kernels.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,43 @@ typedef void (*lpgemm_m_fringe_f32_ker_ft)
6464
lpgemm_post_op_attr post_ops_attr
6565
);
6666

67+
typedef void (*lpgemm_n_fringe_f32_ker_ft)
68+
(
69+
const dim_t m0,
70+
const dim_t k0,
71+
const float* a,
72+
const dim_t rs_a,
73+
const dim_t cs_a,
74+
const dim_t ps_a,
75+
const float* b,
76+
const dim_t rs_b,
77+
const dim_t cs_b,
78+
float* c,
79+
const dim_t rs_c,
80+
const float alpha,
81+
const float beta,
82+
lpgemm_post_op* post_ops_list,
83+
lpgemm_post_op_attr post_ops_attr
84+
);
85+
86+
typedef void (*lpgemm_mn_fringe_f32_mask_ker_ft)
87+
(
88+
const dim_t k0,
89+
const float* a,
90+
const dim_t rs_a,
91+
const dim_t cs_a,
92+
const float* b,
93+
const dim_t rs_b,
94+
const dim_t cs_b,
95+
float* c,
96+
const dim_t rs_c,
97+
const float alpha,
98+
const float beta,
99+
const dim_t n0_rem,
100+
lpgemm_post_op* post_ops_list,
101+
lpgemm_post_op_attr post_ops_attr
102+
);
103+
67104
#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \
68105
void lpgemm_rowvar_ ## LP_SFX \
69106
( \
@@ -242,6 +279,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32);
242279
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32);
243280
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32);
244281
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32);
282+
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x16);
283+
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x16);
284+
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x16);
285+
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x16);
286+
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x16);
245287
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16);
246288
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16);
247289
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16);
@@ -363,6 +405,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
363405

364406
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m);
365407
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m);
408+
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x16m);
366409
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m);
367410
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m);
368411
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m);
@@ -472,6 +515,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
472515
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16);
473516

474517
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
518+
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6xlt16m);
519+
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_6xlt8m);
475520

476521
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16);
477522

@@ -674,6 +719,18 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16);
674719
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16);
675720
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16);
676721

722+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5xlt16);
723+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4xlt16);
724+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3xlt16);
725+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2xlt16);
726+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1xlt16);
727+
728+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_5xlt8);
729+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_4xlt8);
730+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_3xlt8);
731+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_2xlt8);
732+
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_1xlt8);
733+
677734
#define LPGEMM_MN_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
678735
void lpgemm_rowvar_ ## LP_SFX \
679736
( \

0 commit comments

Comments
 (0)