|
4 | 4 | #include "pycore_strhex.h" // _Py_strhex_with_sep() |
5 | 5 | #include "pycore_unicodeobject.h" // _PyUnicode_CheckConsistency() |
6 | 6 |
|
7 | | -/* AVX2 SIMD optimization for hexlify. |
| 7 | +/* SIMD optimization for hexlify. |
8 | 8 | Only available on x86-64 with GCC/Clang. */ |
9 | 9 | #if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__)) |
10 | | -# define PY_HEXLIFY_CAN_COMPILE_AVX2 1 |
| 10 | +# define PY_HEXLIFY_CAN_COMPILE_SIMD 1 |
11 | 11 | # include <cpuid.h> |
12 | 12 | # include <immintrin.h> |
13 | 13 | #else |
14 | | -# define PY_HEXLIFY_CAN_COMPILE_AVX2 0 |
| 14 | +# define PY_HEXLIFY_CAN_COMPILE_SIMD 0 |
15 | 15 | #endif |
16 | 16 |
|
17 | | -#if PY_HEXLIFY_CAN_COMPILE_AVX2 |
| 17 | +#if PY_HEXLIFY_CAN_COMPILE_SIMD |
18 | 18 |
|
19 | | -/* Runtime CPU feature detection (lazy initialization) */ |
20 | | -static int _Py_hexlify_avx2_available = -1; /* -1 = not checked yet */ |
| 19 | +/* Runtime CPU feature detection (lazy initialization) |
| 20 | + -1 = not checked, 0 = no SIMD, 1 = AVX2, 2 = AVX-512 */ |
| 21 | +static int _Py_hexlify_simd_level = -1; |
| 22 | + |
| 23 | +#define PY_HEXLIFY_SIMD_NONE 0 |
| 24 | +#define PY_HEXLIFY_SIMD_AVX2 1 |
| 25 | +#define PY_HEXLIFY_SIMD_AVX512 2 |
21 | 26 |
|
22 | 27 | static void |
23 | 28 | _Py_hexlify_detect_cpu_features(void) |
24 | 29 | { |
25 | 30 | unsigned int eax, ebx, ecx, edx; |
26 | 31 |
|
27 | | - /* Check for AVX2 support: CPUID.7H:EBX bit 5 */ |
28 | | - if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) { |
29 | | - _Py_hexlify_avx2_available = (ebx & (1 << 5)) != 0; |
30 | | - } else { |
31 | | - _Py_hexlify_avx2_available = 0; |
| 32 | + _Py_hexlify_simd_level = PY_HEXLIFY_SIMD_NONE; |
| 33 | + |
| 34 | + if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) { |
| 35 | + return; |
| 36 | + } |
| 37 | + |
| 38 | + /* Check for AVX2: CPUID.7H:EBX bit 5 */ |
| 39 | + int has_avx2 = (ebx & (1 << 5)) != 0; |
| 40 | + |
| 41 | + /* Check for AVX-512F + AVX-512BW + AVX-512VBMI: |
| 42 | + CPUID.7H:EBX bits 16 and 30, ECX bit 1 */ |
| 43 | + int has_avx512f = (ebx & (1 << 16)) != 0; |
| 44 | + int has_avx512bw = (ebx & (1 << 30)) != 0; |
| 45 | + int has_avx512vbmi = (ecx & (1 << 1)) != 0; |
| 46 | + |
| 47 | + if (has_avx512f && has_avx512bw && has_avx512vbmi) { |
| 48 | + _Py_hexlify_simd_level = PY_HEXLIFY_SIMD_AVX512; |
| 49 | + } else if (has_avx2) { |
| 50 | + _Py_hexlify_simd_level = PY_HEXLIFY_SIMD_AVX2; |
32 | 51 | } |
33 | 52 | } |
34 | 53 |
|
35 | 54 | static inline int |
36 | | -_Py_hexlify_can_use_avx2(void) |
| 55 | +_Py_hexlify_get_simd_level(void) |
37 | 56 | { |
38 | | - if (_Py_hexlify_avx2_available < 0) { |
| 57 | + if (_Py_hexlify_simd_level < 0) { |
39 | 58 | _Py_hexlify_detect_cpu_features(); |
40 | 59 | } |
41 | | - return _Py_hexlify_avx2_available; |
| 60 | + return _Py_hexlify_simd_level; |
42 | 61 | } |
43 | 62 |
|
44 | 63 | /* AVX2-accelerated hexlify: converts 32 bytes to 64 hex chars per iteration. |
@@ -96,7 +115,115 @@ _Py_hexlify_avx2(const unsigned char *src, Py_UCS1 *dst, Py_ssize_t len) |
96 | 115 | } |
97 | 116 | } |
98 | 117 |
|
99 | | -#endif /* PY_HEXLIFY_CAN_COMPILE_AVX2 */ |
| 118 | +/* AVX-512 accelerated hexlify: converts 64 bytes to 128 hex chars per iteration. |
| 119 | + Requires AVX-512F, AVX-512BW, and AVX-512VBMI for byte-level permutation. */ |
| 120 | +__attribute__((target("avx512f,avx512bw,avx512vbmi"))) |
| 121 | +static void |
| 122 | +_Py_hexlify_avx512(const unsigned char *src, Py_UCS1 *dst, Py_ssize_t len) |
| 123 | +{ |
| 124 | + const __m512i mask_0f = _mm512_set1_epi8(0x0f); |
| 125 | + const __m512i ascii_0 = _mm512_set1_epi8('0'); |
| 126 | + const __m512i ascii_a = _mm512_set1_epi8('a' - 10); |
| 127 | + const __m512i nine = _mm512_set1_epi8(9); |
| 128 | + |
| 129 | + /* Permutation indices for interleaving hi/lo nibbles. |
| 130 | + We need to transform: |
| 131 | + hi: H0 H1 H2 ... H63 |
| 132 | + lo: L0 L1 L2 ... L63 |
| 133 | + into: |
| 134 | + out0: H0 L0 H1 L1 ... H31 L31 |
| 135 | + out1: H32 L32 H33 L33 ... H63 L63 |
| 136 | + Index values 0-63 select bytes from the first operand (hi), 64-127 from the second (lo); _mm512_set_epi8 lists its arguments from byte 63 down to byte 0. */ |
| 137 | + const __m512i interleave_lo = _mm512_set_epi8( |
| 138 | + 95, 31, 94, 30, 93, 29, 92, 28, 91, 27, 90, 26, 89, 25, 88, 24, |
| 139 | + 87, 23, 86, 22, 85, 21, 84, 20, 83, 19, 82, 18, 81, 17, 80, 16, |
| 140 | + 79, 15, 78, 14, 77, 13, 76, 12, 75, 11, 74, 10, 73, 9, 72, 8, |
| 141 | + 71, 7, 70, 6, 69, 5, 68, 4, 67, 3, 66, 2, 65, 1, 64, 0 |
| 142 | + ); |
| 143 | + const __m512i interleave_hi = _mm512_set_epi8( |
| 144 | + 127, 63, 126, 62, 125, 61, 124, 60, 123, 59, 122, 58, 121, 57, 120, 56, |
| 145 | + 119, 55, 118, 54, 117, 53, 116, 52, 115, 51, 114, 50, 113, 49, 112, 48, |
| 146 | + 111, 47, 110, 46, 109, 45, 108, 44, 107, 43, 106, 42, 105, 41, 104, 40, |
| 147 | + 103, 39, 102, 38, 101, 37, 100, 36, 99, 35, 98, 34, 97, 33, 96, 32 |
| 148 | + ); |
| 149 | + |
| 150 | + Py_ssize_t i = 0; |
| 151 | + |
| 152 | + /* Process 64 bytes at a time */ |
| 153 | + for (; i + 64 <= len; i += 64, dst += 128) { |
| 154 | + /* Load 64 input bytes */ |
| 155 | + __m512i data = _mm512_loadu_si512((const __m512i *)(src + i)); |
| 156 | + |
| 157 | + /* Extract high and low nibbles */ |
| 158 | + __m512i hi = _mm512_and_si512(_mm512_srli_epi16(data, 4), mask_0f); |
| 159 | + __m512i lo = _mm512_and_si512(data, mask_0f); |
| 160 | + |
| 161 | + /* Convert nibbles to hex using masked blend: |
| 162 | + if nibble > 9: use 'a' + (nibble - 10) = nibble + ('a' - 10) |
| 163 | + else: use '0' + nibble */ |
| 164 | + __mmask64 hi_alpha = _mm512_cmpgt_epi8_mask(hi, nine); |
| 165 | + __mmask64 lo_alpha = _mm512_cmpgt_epi8_mask(lo, nine); |
| 166 | + |
| 167 | + __m512i hi_digit = _mm512_add_epi8(hi, ascii_0); |
| 168 | + __m512i hi_letter = _mm512_add_epi8(hi, ascii_a); |
| 169 | + hi = _mm512_mask_blend_epi8(hi_alpha, hi_digit, hi_letter); |
| 170 | + |
| 171 | + __m512i lo_digit = _mm512_add_epi8(lo, ascii_0); |
| 172 | + __m512i lo_letter = _mm512_add_epi8(lo, ascii_a); |
| 173 | + lo = _mm512_mask_blend_epi8(lo_alpha, lo_digit, lo_letter); |
| 174 | + |
| 175 | + /* Interleave hi/lo to get correct output order using permutex2var */ |
| 176 | + __m512i result0 = _mm512_permutex2var_epi8(hi, interleave_lo, lo); |
| 177 | + __m512i result1 = _mm512_permutex2var_epi8(hi, interleave_hi, lo); |
| 178 | + |
| 179 | + /* Store 128 hex characters */ |
| 180 | + _mm512_storeu_si512((__m512i *)dst, result0); |
| 181 | + _mm512_storeu_si512((__m512i *)(dst + 64), result1); |
| 182 | + } |
| 183 | + |
| 184 | + /* Use AVX2 for remaining 32-63 bytes */ |
| 185 | + if (i + 32 <= len) { |
| 186 | + const __m256i mask_0f_256 = _mm256_set1_epi8(0x0f); |
| 187 | + const __m256i ascii_0_256 = _mm256_set1_epi8('0'); |
| 188 | + const __m256i offset_256 = _mm256_set1_epi8('a' - '0' - 10); |
| 189 | + const __m256i nine_256 = _mm256_set1_epi8(9); |
| 190 | + |
| 191 | + __m256i data = _mm256_loadu_si256((const __m256i *)(src + i)); |
| 192 | + __m256i hi = _mm256_and_si256(_mm256_srli_epi16(data, 4), mask_0f_256); |
| 193 | + __m256i lo = _mm256_and_si256(data, mask_0f_256); |
| 194 | + |
| 195 | + __m256i hi_gt9 = _mm256_cmpgt_epi8(hi, nine_256); |
| 196 | + __m256i lo_gt9 = _mm256_cmpgt_epi8(lo, nine_256); |
| 197 | + |
| 198 | + hi = _mm256_add_epi8(hi, ascii_0_256); |
| 199 | + lo = _mm256_add_epi8(lo, ascii_0_256); |
| 200 | + hi = _mm256_add_epi8(hi, _mm256_and_si256(hi_gt9, offset_256)); |
| 201 | + lo = _mm256_add_epi8(lo, _mm256_and_si256(lo_gt9, offset_256)); |
| 202 | + |
| 203 | + __m256i mixed_lo = _mm256_unpacklo_epi8(hi, lo); |
| 204 | + __m256i mixed_hi = _mm256_unpackhi_epi8(hi, lo); |
| 205 | + |
| 206 | + __m256i r0 = _mm256_permute2x128_si256(mixed_lo, mixed_hi, 0x20); |
| 207 | + __m256i r1 = _mm256_permute2x128_si256(mixed_lo, mixed_hi, 0x31); |
| 208 | + |
| 209 | + _mm256_storeu_si256((__m256i *)dst, r0); |
| 210 | + _mm256_storeu_si256((__m256i *)(dst + 32), r1); |
| 211 | + |
| 212 | + i += 32; |
| 213 | + dst += 64; |
| 214 | + } |
| 215 | + |
| 216 | + /* Scalar fallback for remaining 0-31 bytes */ |
| 217 | + for (; i < len; i++, dst += 2) { |
| 218 | + unsigned int c = src[i]; |
| 219 | + unsigned int hi = c >> 4; |
| 220 | + unsigned int lo = c & 0x0f; |
| 221 | + dst[0] = (Py_UCS1)(hi + '0' + (hi > 9) * ('a' - '0' - 10)); |
| 222 | + dst[1] = (Py_UCS1)(lo + '0' + (lo > 9) * ('a' - '0' - 10)); |
| 223 | + } |
| 224 | +} |
| 225 | + |
| 226 | +#endif /* PY_HEXLIFY_CAN_COMPILE_SIMD */ |
100 | 227 |
|
101 | 228 | static PyObject *_Py_strhex_impl(const char* argbuf, const Py_ssize_t arglen, |
102 | 229 | PyObject* sep, int bytes_per_sep_group, |
@@ -176,9 +303,13 @@ static PyObject *_Py_strhex_impl(const char* argbuf, const Py_ssize_t arglen, |
176 | 303 | unsigned char c; |
177 | 304 |
|
178 | 305 | if (bytes_per_sep_group == 0) { |
179 | | -#if PY_HEXLIFY_CAN_COMPILE_AVX2 |
180 | | - /* Use AVX2 for inputs >= 32 bytes when available */ |
181 | | - if (arglen >= 32 && _Py_hexlify_can_use_avx2()) { |
| 306 | +#if PY_HEXLIFY_CAN_COMPILE_SIMD |
| 307 | + int simd_level = _Py_hexlify_get_simd_level(); |
| 308 | + /* Use AVX-512 for inputs >= 64 bytes, AVX2 for >= 32 bytes */ |
| 309 | + if (arglen >= 64 && simd_level >= PY_HEXLIFY_SIMD_AVX512) { |
| 310 | + _Py_hexlify_avx512((const unsigned char *)argbuf, retbuf, arglen); |
| 311 | + } |
| 312 | + else if (arglen >= 32 && simd_level >= PY_HEXLIFY_SIMD_AVX2) { |
182 | 313 | _Py_hexlify_avx2((const unsigned char *)argbuf, retbuf, arglen); |
183 | 314 | } |
184 | 315 | else |
|
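For reference, the two permutex2var index tables used in _Py_hexlify_avx512 can be derived mechanically: output byte 2*i has to come from hi[i] (index value i, bit 6 clear selects the first operand) and output byte 2*i+1 from lo[i] (index value 64 + i, bit 6 set selects the second operand), and _mm512_set_epi8 takes its arguments from byte 63 down to byte 0. The small throwaway generator below is not part of the patch; it is shown only so the constants can be re-checked, and prints both tables in that argument order:

    #include <stdio.h>

    int main(void)
    {
        /* half 0: output bytes 0..63   (input bytes 0..31)  -> interleave_lo
           half 1: output bytes 64..127 (input bytes 32..63)  -> interleave_hi */
        for (int half = 0; half < 2; half++) {
            printf("table %d (arguments e63..e0):\n", half);
            for (int j = 63; j >= 0; j--) {
                int i = half * 32 + j / 2;        /* source input byte */
                int idx = (j % 2) ? 64 + i : i;   /* odd output byte -> lo operand */
                printf("%4d,%s", idx, (j % 16 == 0) ? "\n" : " ");
            }
        }
        return 0;
    }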
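The 32-byte AVX2 tail inside _Py_hexlify_avx512 (and the standalone AVX2 path) can also be sanity-checked in isolation. The harness below is a sketch, not part of the patch: it repeats the same intrinsic sequence as the tail above and compares the result against a plain scalar conversion; the test pattern is arbitrary, and it assumes an AVX2-capable machine and a build along the lines of gcc -O2 -mavx2.

    #include <immintrin.h>
    #include <stdio.h>
    #include <string.h>

    /* Scalar reference: same nibble-to-ASCII mapping as the fallback loops above. */
    static void hex_scalar(const unsigned char *src, unsigned char *dst, size_t len)
    {
        for (size_t i = 0; i < len; i++, dst += 2) {
            unsigned int hi = src[i] >> 4, lo = src[i] & 0x0f;
            dst[0] = (unsigned char)(hi + '0' + (hi > 9) * ('a' - '0' - 10));
            dst[1] = (unsigned char)(lo + '0' + (lo > 9) * ('a' - '0' - 10));
        }
    }

    /* Same intrinsic sequence as the 32-byte AVX2 tail in the patch. */
    static void hex_avx2_block(const unsigned char *src, unsigned char *dst)
    {
        const __m256i mask_0f = _mm256_set1_epi8(0x0f);
        const __m256i ascii_0 = _mm256_set1_epi8('0');
        const __m256i offset  = _mm256_set1_epi8('a' - '0' - 10);
        const __m256i nine    = _mm256_set1_epi8(9);

        __m256i data = _mm256_loadu_si256((const __m256i *)src);
        __m256i hi = _mm256_and_si256(_mm256_srli_epi16(data, 4), mask_0f);
        __m256i lo = _mm256_and_si256(data, mask_0f);
        hi = _mm256_add_epi8(_mm256_add_epi8(hi, ascii_0),
                             _mm256_and_si256(_mm256_cmpgt_epi8(hi, nine), offset));
        lo = _mm256_add_epi8(_mm256_add_epi8(lo, ascii_0),
                             _mm256_and_si256(_mm256_cmpgt_epi8(lo, nine), offset));

        __m256i mixed_lo = _mm256_unpacklo_epi8(hi, lo);
        __m256i mixed_hi = _mm256_unpackhi_epi8(hi, lo);
        _mm256_storeu_si256((__m256i *)dst,
                            _mm256_permute2x128_si256(mixed_lo, mixed_hi, 0x20));
        _mm256_storeu_si256((__m256i *)(dst + 32),
                            _mm256_permute2x128_si256(mixed_lo, mixed_hi, 0x31));
    }

    int main(void)
    {
        unsigned char src[32], want[64], got[64];
        for (int i = 0; i < 32; i++) {
            src[i] = (unsigned char)(i * 37 + 5);   /* arbitrary test pattern */
        }
        hex_scalar(src, want, 32);
        hex_avx2_block(src, got);
        puts(memcmp(want, got, 64) == 0 ? "ok" : "MISMATCH");
        return 0;
    }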