Skip to content

Commit 90da084

Browse files
gpshead and claude committed
pystrhex: Add AVX-512 SIMD optimization for hex conversion
Add AVX-512 accelerated hexlify for the no-separator path when available. This processes 64 bytes per iteration using: - AVX-512F, AVX-512BW for 512-bit operations - AVX-512VBMI for efficient byte-level permutation (permutex2var_epi8) - Masked blend for branchless nibble-to-hex conversion Runtime detection via CPUID checks for all three required extensions. Falls back to AVX2 for 32-63 byte remainders, then scalar for <32 bytes. CPU hierarchy: - AVX-512 (F+BW+VBMI): 64 bytes/iteration, uses for inputs >= 64 bytes - AVX2: 32 bytes/iteration, uses for inputs >= 32 bytes - Scalar: remaining bytes Expected performance improvement over AVX2 for large inputs (4KB+) due to doubled throughput per iteration. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent bbb4a8a commit 90da084

1 file changed

Lines changed: 149 additions & 18 deletions

File tree

Python/pystrhex.c

Lines changed: 149 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -4,41 +4,60 @@
44
#include "pycore_strhex.h" // _Py_strhex_with_sep()
55
#include "pycore_unicodeobject.h" // _PyUnicode_CheckConsistency()
66

7-
/* AVX2 SIMD optimization for hexlify.
7+
/* SIMD optimization for hexlify.
88
Only available on x86-64 with GCC/Clang. */
99
#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
10-
# define PY_HEXLIFY_CAN_COMPILE_AVX2 1
10+
# define PY_HEXLIFY_CAN_COMPILE_SIMD 1
1111
# include <cpuid.h>
1212
# include <immintrin.h>
1313
#else
14-
# define PY_HEXLIFY_CAN_COMPILE_AVX2 0
14+
# define PY_HEXLIFY_CAN_COMPILE_SIMD 0
1515
#endif
1616

17-
#if PY_HEXLIFY_CAN_COMPILE_AVX2
17+
#if PY_HEXLIFY_CAN_COMPILE_SIMD
1818

19-
/* Runtime CPU feature detection (lazy initialization) */
20-
static int _Py_hexlify_avx2_available = -1; /* -1 = not checked yet */
19+
/* Runtime CPU feature detection (lazy initialization)
20+
-1 = not checked, 0 = no SIMD, 1 = AVX2, 2 = AVX-512 */
21+
static int _Py_hexlify_simd_level = -1;
22+
23+
#define PY_HEXLIFY_SIMD_NONE 0
24+
#define PY_HEXLIFY_SIMD_AVX2 1
25+
#define PY_HEXLIFY_SIMD_AVX512 2
2126

2227
static void
2328
_Py_hexlify_detect_cpu_features(void)
2429
{
2530
unsigned int eax, ebx, ecx, edx;
2631

27-
/* Check for AVX2 support: CPUID.7H:EBX bit 5 */
28-
if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
29-
_Py_hexlify_avx2_available = (ebx & (1 << 5)) != 0;
30-
} else {
31-
_Py_hexlify_avx2_available = 0;
32+
_Py_hexlify_simd_level = PY_HEXLIFY_SIMD_NONE;
33+
34+
if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
35+
return;
36+
}
37+
38+
/* Check for AVX2: CPUID.7H:EBX bit 5 */
39+
int has_avx2 = (ebx & (1 << 5)) != 0;
40+
41+
/* Check for AVX-512F + AVX-512BW + AVX-512VBMI:
42+
CPUID.7H:EBX bits 16 and 30, ECX bit 1 */
43+
int has_avx512f = (ebx & (1 << 16)) != 0;
44+
int has_avx512bw = (ebx & (1 << 30)) != 0;
45+
int has_avx512vbmi = (ecx & (1 << 1)) != 0;
46+
47+
if (has_avx512f && has_avx512bw && has_avx512vbmi) {
48+
_Py_hexlify_simd_level = PY_HEXLIFY_SIMD_AVX512;
49+
} else if (has_avx2) {
50+
_Py_hexlify_simd_level = PY_HEXLIFY_SIMD_AVX2;
3251
}
3352
}
3453

3554
static inline int
36-
_Py_hexlify_can_use_avx2(void)
55+
_Py_hexlify_get_simd_level(void)
3756
{
38-
if (_Py_hexlify_avx2_available < 0) {
57+
if (_Py_hexlify_simd_level < 0) {
3958
_Py_hexlify_detect_cpu_features();
4059
}
41-
return _Py_hexlify_avx2_available;
60+
return _Py_hexlify_simd_level;
4261
}
4362

4463
/* AVX2-accelerated hexlify: converts 32 bytes to 64 hex chars per iteration.
@@ -96,7 +115,115 @@ _Py_hexlify_avx2(const unsigned char *src, Py_UCS1 *dst, Py_ssize_t len)
96115
}
97116
}
98117

99-
#endif /* PY_HEXLIFY_CAN_COMPILE_AVX2 */
118+
/* AVX-512 accelerated hexlify: converts 64 bytes to 128 hex chars per iteration.
119+
Requires AVX-512F, AVX-512BW, and AVX-512VBMI for byte-level permutation. */
120+
__attribute__((target("avx512f,avx512bw,avx512vbmi")))
121+
static void
122+
_Py_hexlify_avx512(const unsigned char *src, Py_UCS1 *dst, Py_ssize_t len)
123+
{
124+
const __m512i mask_0f = _mm512_set1_epi8(0x0f);
125+
const __m512i ascii_0 = _mm512_set1_epi8('0');
126+
const __m512i ascii_a = _mm512_set1_epi8('a' - 10);
127+
const __m512i nine = _mm512_set1_epi8(9);
128+
129+
/* Permutation indices for interleaving hi/lo nibbles.
130+
We need to transform:
131+
hi: H0 H1 H2 ... H63
132+
lo: L0 L1 L2 ... L63
133+
into:
134+
out0: H0 L0 H1 L1 ... H31 L31
135+
out1: H32 L32 H33 L33 ... H63 L63
136+
*/
137+
const __m512i interleave_lo = _mm512_set_epi8(
138+
39, 103, 38, 102, 37, 101, 36, 100, 35, 99, 34, 98, 33, 97, 32, 96,
139+
47, 111, 46, 110, 45, 109, 44, 108, 43, 107, 42, 106, 41, 105, 40, 104,
140+
55, 119, 54, 118, 53, 117, 52, 116, 51, 115, 50, 114, 49, 113, 48, 112,
141+
63, 127, 62, 126, 61, 125, 60, 124, 59, 123, 58, 122, 57, 121, 56, 120
142+
);
143+
const __m512i interleave_hi = _mm512_set_epi8(
144+
7, 71, 6, 70, 5, 69, 4, 68, 3, 67, 2, 66, 1, 65, 0, 64,
145+
15, 79, 14, 78, 13, 77, 12, 76, 11, 75, 10, 74, 9, 73, 8, 72,
146+
23, 87, 22, 86, 21, 85, 20, 84, 19, 83, 18, 82, 17, 81, 16, 80,
147+
31, 95, 30, 94, 29, 93, 28, 92, 27, 91, 26, 90, 25, 89, 24, 88
148+
);
149+
150+
Py_ssize_t i = 0;
151+
152+
/* Process 64 bytes at a time */
153+
for (; i + 64 <= len; i += 64, dst += 128) {
154+
/* Load 64 input bytes */
155+
__m512i data = _mm512_loadu_si512((const __m512i *)(src + i));
156+
157+
/* Extract high and low nibbles */
158+
__m512i hi = _mm512_and_si512(_mm512_srli_epi16(data, 4), mask_0f);
159+
__m512i lo = _mm512_and_si512(data, mask_0f);
160+
161+
/* Convert nibbles to hex using masked blend:
162+
if nibble > 9: use 'a' + (nibble - 10) = nibble + ('a' - 10)
163+
else: use '0' + nibble */
164+
__mmask64 hi_alpha = _mm512_cmpgt_epi8_mask(hi, nine);
165+
__mmask64 lo_alpha = _mm512_cmpgt_epi8_mask(lo, nine);
166+
167+
__m512i hi_digit = _mm512_add_epi8(hi, ascii_0);
168+
__m512i hi_letter = _mm512_add_epi8(hi, ascii_a);
169+
hi = _mm512_mask_blend_epi8(hi_alpha, hi_digit, hi_letter);
170+
171+
__m512i lo_digit = _mm512_add_epi8(lo, ascii_0);
172+
__m512i lo_letter = _mm512_add_epi8(lo, ascii_a);
173+
lo = _mm512_mask_blend_epi8(lo_alpha, lo_digit, lo_letter);
174+
175+
/* Interleave hi/lo to get correct output order using permutex2var */
176+
__m512i result0 = _mm512_permutex2var_epi8(hi, interleave_hi, lo);
177+
__m512i result1 = _mm512_permutex2var_epi8(hi, interleave_lo, lo);
178+
179+
/* Store 128 hex characters */
180+
_mm512_storeu_si512((__m512i *)dst, result0);
181+
_mm512_storeu_si512((__m512i *)(dst + 64), result1);
182+
}
183+
184+
/* Use AVX2 for remaining 32-63 bytes */
185+
if (i + 32 <= len) {
186+
const __m256i mask_0f_256 = _mm256_set1_epi8(0x0f);
187+
const __m256i ascii_0_256 = _mm256_set1_epi8('0');
188+
const __m256i offset_256 = _mm256_set1_epi8('a' - '0' - 10);
189+
const __m256i nine_256 = _mm256_set1_epi8(9);
190+
191+
__m256i data = _mm256_loadu_si256((const __m256i *)(src + i));
192+
__m256i hi = _mm256_and_si256(_mm256_srli_epi16(data, 4), mask_0f_256);
193+
__m256i lo = _mm256_and_si256(data, mask_0f_256);
194+
195+
__m256i hi_gt9 = _mm256_cmpgt_epi8(hi, nine_256);
196+
__m256i lo_gt9 = _mm256_cmpgt_epi8(lo, nine_256);
197+
198+
hi = _mm256_add_epi8(hi, ascii_0_256);
199+
lo = _mm256_add_epi8(lo, ascii_0_256);
200+
hi = _mm256_add_epi8(hi, _mm256_and_si256(hi_gt9, offset_256));
201+
lo = _mm256_add_epi8(lo, _mm256_and_si256(lo_gt9, offset_256));
202+
203+
__m256i mixed_lo = _mm256_unpacklo_epi8(hi, lo);
204+
__m256i mixed_hi = _mm256_unpackhi_epi8(hi, lo);
205+
206+
__m256i r0 = _mm256_permute2x128_si256(mixed_lo, mixed_hi, 0x20);
207+
__m256i r1 = _mm256_permute2x128_si256(mixed_lo, mixed_hi, 0x31);
208+
209+
_mm256_storeu_si256((__m256i *)dst, r0);
210+
_mm256_storeu_si256((__m256i *)(dst + 32), r1);
211+
212+
i += 32;
213+
dst += 64;
214+
}
215+
216+
/* Scalar fallback for remaining 0-31 bytes */
217+
for (; i < len; i++, dst += 2) {
218+
unsigned int c = src[i];
219+
unsigned int hi = c >> 4;
220+
unsigned int lo = c & 0x0f;
221+
dst[0] = (Py_UCS1)(hi + '0' + (hi > 9) * ('a' - '0' - 10));
222+
dst[1] = (Py_UCS1)(lo + '0' + (lo > 9) * ('a' - '0' - 10));
223+
}
224+
}
225+
226+
#endif /* PY_HEXLIFY_CAN_COMPILE_SIMD */
100227

101228
static PyObject *_Py_strhex_impl(const char* argbuf, const Py_ssize_t arglen,
102229
PyObject* sep, int bytes_per_sep_group,
@@ -176,9 +303,13 @@ static PyObject *_Py_strhex_impl(const char* argbuf, const Py_ssize_t arglen,
176303
unsigned char c;
177304

178305
if (bytes_per_sep_group == 0) {
179-
#if PY_HEXLIFY_CAN_COMPILE_AVX2
180-
/* Use AVX2 for inputs >= 32 bytes when available */
181-
if (arglen >= 32 && _Py_hexlify_can_use_avx2()) {
306+
#if PY_HEXLIFY_CAN_COMPILE_SIMD
307+
int simd_level = _Py_hexlify_get_simd_level();
308+
/* Use AVX-512 for inputs >= 64 bytes, AVX2 for >= 32 bytes */
309+
if (arglen >= 64 && simd_level >= PY_HEXLIFY_SIMD_AVX512) {
310+
_Py_hexlify_avx512((const unsigned char *)argbuf, retbuf, arglen);
311+
}
312+
else if (arglen >= 32 && simd_level >= PY_HEXLIFY_SIMD_AVX2) {
182313
_Py_hexlify_avx2((const unsigned char *)argbuf, retbuf, arglen);
183314
}
184315
else

0 commit comments

Comments
 (0)