Skip to content

Commit 7081e72

Browse files
authored
Add SWAR versions of Base validations (#15357)
1 parent 3e3ce13 commit 7081e72

1 file changed

Lines changed: 243 additions & 74 deletions

File tree

lib/elixir/lib/base.ex

Lines changed: 243 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,220 @@ defmodule Base do
155155
for <<char::8 <- string>>, char not in ~c"\s\t\r\n", into: <<>>, do: <<char::8>>
156156
end
157157

158+
# SWAR (SIMD Within A Register) fast paths for valid16?/2 and valid32?/2
159+
# (non-hex). Each chunk of 8 bytes is validated in one guard: 7 bytes via
160+
# bitwise arithmetic on a single 56-bit integer, plus a per-byte range
161+
# check for the 8th byte. 56 bits is the largest width that fits in a BEAM
162+
# small int on 64-bit (fixnum range is 59-bit signed); at 64 bits every
163+
# `w + 0x80..` would allocate a bignum on the heap and the optimisation
164+
# would collapse. See https://github.com/erlang/otp/pull/10938 for the
165+
# corresponding pattern in OTP.
166+
@swar_mask80 0x80808080808080
167+
168+
# Per-range SWAR constants, broadcast across 7 lanes. Naming convention:
169+
# @swar_ge_X = 0x80 - X → high bit of `(w + @swar_ge_X)` lane is set
170+
# iff that byte is ≥ X
171+
# @swar_gt_X = 0x7F - X → high bit of `(w + @swar_gt_X)` lane is set
172+
# iff that byte is > X
173+
# A byte is in range [lo, hi] iff
174+
# `bxor(w + @swar_ge_lo, w + @swar_gt_hi)` has its high bit set.
175+
@swar_ge_0 0x50505050505050
176+
@swar_gt_9 0x46464646464646
177+
@swar_ge_2 0x4E4E4E4E4E4E4E
178+
@swar_gt_7 0x48484848484848
179+
@swar_ge_A 0x3F3F3F3F3F3F3F
180+
@swar_gt_F 0x39393939393939
181+
@swar_gt_V 0x29292929292929
182+
@swar_gt_Z 0x25252525252525
183+
@swar_ge_a 0x1F1F1F1F1F1F1F
184+
@swar_gt_f 0x19191919191919
185+
@swar_gt_v 0x09090909090909
186+
@swar_gt_z 0x05050505050505
187+
188+
# For base64 standard, '/' (0x2F) sits exactly one below '0' (0x30), so we
189+
# extend the digit range to [0x2F, 0x39], which absorbs '/' into one range
190+
# check — saves one Mycroft singleton. Trick lifted from
191+
# https://lemire.me/blog/2025/04/13/detect-control-characters-quotes-and-backslashes-efficiently-using-swar/
192+
@swar_ge_slash 0x51515151515151
193+
194+
# Mycroft zero-byte detection for base64 singletons (+, -, _).
195+
# Per lane: high bit set iff `bxor(w, K*ones) - 0x01..01` has its high bit
196+
# set, i.e. that byte's V value was 0 → original byte was K. Simplified
197+
# (no `bnot V` term) — for ASCII-gated `w`, borrow-propagation false
198+
# positives only occur for adjacent bytes that happen to equal `K xor 0x01`,
199+
# which is outside the base64 alphabet, so it never matters here.
200+
# Pattern follows https://github.com/elixir-lang/elixir/pull/15255.
201+
@swar_mask01 0x01010101010101
202+
@swar_plus_x7 0x2B2B2B2B2B2B2B
203+
@swar_dash_x7 0x2D2D2D2D2D2D2D
204+
@swar_under_x7 0x5F5F5F5F5F5F5F
205+
206+
# Per-byte validity guards (used in both the SWAR clauses for the 8th byte
207+
# of each stride and in the body of the sub-8-byte tail clauses).
208+
defguardp valid_char16upper?(c) when c in ?0..?9 or c in ?A..?F
209+
defguardp valid_char16lower?(c) when c in ?0..?9 or c in ?a..?f
210+
defguardp valid_char16mixed?(c) when c in ?0..?9 or c in ?A..?F or c in ?a..?f
211+
212+
defguardp valid_char32upper?(c) when c in ?A..?Z or c in ?2..?7
213+
defguardp valid_char32lower?(c) when c in ?a..?z or c in ?2..?7
214+
defguardp valid_char32mixed?(c) when c in ?A..?Z or c in ?a..?z or c in ?2..?7
215+
216+
# Most common range first — letters dominate (22/32) over digits (10/32)
217+
# in hex base32, so letters go first in the OR short-circuit.
218+
defguardp valid_char32hexupper?(c) when c in ?A..?V or c in ?0..?9
219+
defguardp valid_char32hexlower?(c) when c in ?a..?v or c in ?0..?9
220+
defguardp valid_char32hexmixed?(c) when c in ?A..?V or c in ?a..?v or c in ?0..?9
221+
222+
defguardp valid_char64base?(c)
223+
when c in ?A..?Z or c in ?a..?z or c in ?0..?9 or c == ?+ or c == ?/
224+
225+
defguardp valid_char64url?(c)
226+
when c in ?A..?Z or c in ?a..?z or c in ?0..?9 or c == ?- or c == ?_
227+
228+
# SWAR 7-byte word validity. Structure for each guard:
229+
# 1. ASCII gate `band(w, MASK80) == 0` — every byte < 0x80 so the
230+
# additions below cannot carry across lanes.
231+
# 2. "Each byte is in range A OR range B (OR range C)" gate — OR per-
232+
# range XOR masks (high bit set in lane iff byte in that range), AND
233+
# with MASK80, demand all 7 high bits set.
234+
defguardp valid_word16upper?(w)
235+
when band(w, @swar_mask80) == 0 and
236+
band(
237+
bor(
238+
bxor(w + @swar_ge_0, w + @swar_gt_9),
239+
bxor(w + @swar_ge_A, w + @swar_gt_F)
240+
),
241+
@swar_mask80
242+
) == @swar_mask80
243+
244+
defguardp valid_word16lower?(w)
245+
when band(w, @swar_mask80) == 0 and
246+
band(
247+
bor(
248+
bxor(w + @swar_ge_0, w + @swar_gt_9),
249+
bxor(w + @swar_ge_a, w + @swar_gt_f)
250+
),
251+
@swar_mask80
252+
) == @swar_mask80
253+
254+
defguardp valid_word16mixed?(w)
255+
when band(w, @swar_mask80) == 0 and
256+
band(
257+
bor(
258+
bor(
259+
bxor(w + @swar_ge_0, w + @swar_gt_9),
260+
bxor(w + @swar_ge_A, w + @swar_gt_F)
261+
),
262+
bxor(w + @swar_ge_a, w + @swar_gt_f)
263+
),
264+
@swar_mask80
265+
) == @swar_mask80
266+
267+
defguardp valid_word32upper?(w)
268+
when band(w, @swar_mask80) == 0 and
269+
band(
270+
bor(
271+
bxor(w + @swar_ge_A, w + @swar_gt_Z),
272+
bxor(w + @swar_ge_2, w + @swar_gt_7)
273+
),
274+
@swar_mask80
275+
) == @swar_mask80
276+
277+
defguardp valid_word32lower?(w)
278+
when band(w, @swar_mask80) == 0 and
279+
band(
280+
bor(
281+
bxor(w + @swar_ge_a, w + @swar_gt_z),
282+
bxor(w + @swar_ge_2, w + @swar_gt_7)
283+
),
284+
@swar_mask80
285+
) == @swar_mask80
286+
287+
defguardp valid_word32mixed?(w)
288+
when band(w, @swar_mask80) == 0 and
289+
band(
290+
bor(
291+
bor(
292+
bxor(w + @swar_ge_A, w + @swar_gt_Z),
293+
bxor(w + @swar_ge_a, w + @swar_gt_z)
294+
),
295+
bxor(w + @swar_ge_2, w + @swar_gt_7)
296+
),
297+
@swar_mask80
298+
) == @swar_mask80
299+
300+
defguardp valid_word32hexupper?(w)
301+
when band(w, @swar_mask80) == 0 and
302+
band(
303+
bor(
304+
bxor(w + @swar_ge_0, w + @swar_gt_9),
305+
bxor(w + @swar_ge_A, w + @swar_gt_V)
306+
),
307+
@swar_mask80
308+
) == @swar_mask80
309+
310+
defguardp valid_word32hexlower?(w)
311+
when band(w, @swar_mask80) == 0 and
312+
band(
313+
bor(
314+
bxor(w + @swar_ge_0, w + @swar_gt_9),
315+
bxor(w + @swar_ge_a, w + @swar_gt_v)
316+
),
317+
@swar_mask80
318+
) == @swar_mask80
319+
320+
defguardp valid_word32hexmixed?(w)
321+
when band(w, @swar_mask80) == 0 and
322+
band(
323+
bor(
324+
bor(
325+
bxor(w + @swar_ge_0, w + @swar_gt_9),
326+
bxor(w + @swar_ge_A, w + @swar_gt_V)
327+
),
328+
bxor(w + @swar_ge_a, w + @swar_gt_v)
329+
),
330+
@swar_mask80
331+
) == @swar_mask80
332+
333+
# base64 SWAR word validity: 3 ranges (A-Z, a-z, 0-9) OR'd with singletons.
334+
# For base, the digit range is extended to [0x2F, 0x39] to absorb '/' as
335+
# part of one range (Lemire merge), leaving only '+' as a Mycroft singleton.
336+
# For url, the singletons '-' and '_' are detected via two Mycroft terms.
337+
defguardp valid_word64base?(w)
338+
when band(w, @swar_mask80) == 0 and
339+
band(
340+
bor(
341+
bor(
342+
bor(
343+
bxor(w + @swar_ge_A, w + @swar_gt_Z),
344+
bxor(w + @swar_ge_a, w + @swar_gt_z)
345+
),
346+
bxor(w + @swar_ge_slash, w + @swar_gt_9)
347+
),
348+
bxor(w, @swar_plus_x7) - @swar_mask01
349+
),
350+
@swar_mask80
351+
) == @swar_mask80
352+
353+
defguardp valid_word64url?(w)
354+
when band(w, @swar_mask80) == 0 and
355+
band(
356+
bor(
357+
bor(
358+
bor(
359+
bxor(w + @swar_ge_A, w + @swar_gt_Z),
360+
bxor(w + @swar_ge_a, w + @swar_gt_z)
361+
),
362+
bxor(w + @swar_ge_0, w + @swar_gt_9)
363+
),
364+
bor(
365+
bxor(w, @swar_dash_x7) - @swar_mask01,
366+
bxor(w, @swar_under_x7) - @swar_mask01
367+
)
368+
),
369+
@swar_mask80
370+
) == @swar_mask80
371+
158372
@doc """
159373
Encodes a binary string into a base 16 encoded string.
160374
@@ -371,45 +585,21 @@ defmodule Base do
371585
decode_name = :"decode16#{base}!"
372586
validate_name = :"validate16#{base}?"
373587
valid_char_name = :"valid_char16#{base}?"
588+
valid_word_name = :"valid_word16#{base}?"
374589

375590
{min, decoded} = to_decode_list.(alphabet)
376591

377-
defp unquote(validate_name)(<<>>), do: true
378-
379-
defp unquote(validate_name)(<<c1, c2, c3, c4, c5, c6, c7, c8, rest::binary>>) do
380-
unquote(valid_char_name)(c1) and
381-
unquote(valid_char_name)(c2) and
382-
unquote(valid_char_name)(c3) and
383-
unquote(valid_char_name)(c4) and
384-
unquote(valid_char_name)(c5) and
385-
unquote(valid_char_name)(c6) and
386-
unquote(valid_char_name)(c7) and
387-
unquote(valid_char_name)(c8) and
388-
unquote(validate_name)(rest)
389-
end
390-
391-
defp unquote(validate_name)(<<c1, c2, c3, c4, rest::binary>>) do
392-
unquote(valid_char_name)(c1) and
393-
unquote(valid_char_name)(c2) and
394-
unquote(valid_char_name)(c3) and
395-
unquote(valid_char_name)(c4) and
396-
unquote(validate_name)(rest)
397-
end
398-
399-
defp unquote(validate_name)(<<c1, c2, rest::binary>>) do
400-
unquote(valid_char_name)(c1) and
401-
unquote(valid_char_name)(c2) and
402-
unquote(validate_name)(rest)
403-
end
592+
# SWAR fast path: 7 bytes per stride, validated entirely via
593+
# `valid_word16<base>?` in the body. The `and` short-circuits when SWAR
594+
# fails on any byte. Tail bytes (1-6 leftover) recurse through the
595+
# single-byte clause below.
596+
defp unquote(validate_name)(<<w::56, rest::binary>>),
597+
do: unquote(valid_word_name)(w) and unquote(validate_name)(rest)
404598

405-
defp unquote(validate_name)(<<_char, _rest::binary>>), do: false
406-
407-
@compile {:inline, [{valid_char_name, 1}]}
408-
defp unquote(valid_char_name)(char)
409-
when elem({unquote_splicing(decoded)}, char - unquote(min)) != nil,
410-
do: true
599+
defp unquote(validate_name)(<<>>), do: true
411600

412-
defp unquote(valid_char_name)(_char), do: false
601+
defp unquote(validate_name)(<<char, rest::binary>>),
602+
do: unquote(valid_char_name)(char) and unquote(validate_name)(rest)
413603

414604
defp unquote(decode_name)(char) do
415605
index = char - unquote(min)
@@ -781,23 +971,19 @@ defmodule Base do
781971
validate_name = :"validate64#{base}?"
782972
validate_main_name = :"validate_main64#{validate_name}?"
783973
valid_char_name = :"valid_char64#{base}?"
974+
valid_word_name = :"valid_word64#{base}?"
784975
{min, decoded} = alphabet |> Enum.with_index() |> to_decode_list.()
785976

977+
# SWAR fast path: 7 bytes per stride, validated via `valid_word64<base>?`
978+
# in the body. Tail leftover (1-6 bytes after a 7-byte stride hits an
979+
# 8-byte-multiple `main`) recurses through the single-byte clause.
980+
defp unquote(validate_main_name)(<<w::56, rest::binary>>),
981+
do: unquote(valid_word_name)(w) and unquote(validate_main_name)(rest)
982+
786983
defp unquote(validate_main_name)(<<>>), do: true
787984

788-
defp unquote(validate_main_name)(
789-
<<c1::8, c2::8, c3::8, c4::8, c5::8, c6::8, c7::8, c8::8, rest::binary>>
790-
) do
791-
unquote(valid_char_name)(c1) and
792-
unquote(valid_char_name)(c2) and
793-
unquote(valid_char_name)(c3) and
794-
unquote(valid_char_name)(c4) and
795-
unquote(valid_char_name)(c5) and
796-
unquote(valid_char_name)(c6) and
797-
unquote(valid_char_name)(c7) and
798-
unquote(valid_char_name)(c8) and
799-
unquote(validate_main_name)(rest)
800-
end
985+
defp unquote(validate_main_name)(<<char, rest::binary>>),
986+
do: unquote(valid_char_name)(char) and unquote(validate_main_name)(rest)
801987

802988
defp unquote(validate_name)(<<>>, _pad?), do: true
803989

@@ -883,13 +1069,6 @@ defmodule Base do
8831069
end
8841070
end
8851071

886-
@compile {:inline, [{valid_char_name, 1}]}
887-
defp unquote(valid_char_name)(char)
888-
when elem({unquote_splicing(decoded)}, char - unquote(min)) != nil,
889-
do: true
890-
891-
defp unquote(valid_char_name)(_char), do: false
892-
8931072
defp unquote(decode_name)(char) do
8941073
index = char - unquote(min)
8951074

@@ -1445,21 +1624,18 @@ defmodule Base do
14451624
valid_char_name = :"valid_char32#{base}?"
14461625
{min, decoded} = to_decode_list.(alphabet)
14471626

1627+
# SWAR fast path: 7 bytes per stride, validated via `valid_word32<base>?`
1628+
# in the body. Tail leftover (1-6 bytes after a 7-byte stride hits an
1629+
# 8-byte-multiple `main`) recurses through the single-byte clause.
1630+
valid_word_name = :"valid_word32#{base}?"
1631+
1632+
defp unquote(validate_main_name)(<<w::56, rest::binary>>),
1633+
do: unquote(valid_word_name)(w) and unquote(validate_main_name)(rest)
1634+
14481635
defp unquote(validate_main_name)(<<>>), do: true
14491636

1450-
defp unquote(validate_main_name)(
1451-
<<c1::8, c2::8, c3::8, c4::8, c5::8, c6::8, c7::8, c8::8, rest::binary>>
1452-
) do
1453-
unquote(valid_char_name)(c1) and
1454-
unquote(valid_char_name)(c2) and
1455-
unquote(valid_char_name)(c3) and
1456-
unquote(valid_char_name)(c4) and
1457-
unquote(valid_char_name)(c5) and
1458-
unquote(valid_char_name)(c6) and
1459-
unquote(valid_char_name)(c7) and
1460-
unquote(valid_char_name)(c8) and
1461-
unquote(validate_main_name)(rest)
1462-
end
1637+
defp unquote(validate_main_name)(<<char, rest::binary>>),
1638+
do: unquote(valid_char_name)(char) and unquote(validate_main_name)(rest)
14631639

14641640
defp unquote(validate_name)(<<>>, _pad?), do: true
14651641

@@ -1539,13 +1715,6 @@ defmodule Base do
15391715
end
15401716
end
15411717

1542-
@compile {:inline, [{valid_char_name, 1}]}
1543-
defp unquote(valid_char_name)(char)
1544-
when elem({unquote_splicing(decoded)}, char - unquote(min)) != nil,
1545-
do: true
1546-
1547-
defp unquote(valid_char_name)(_char), do: false
1548-
15491718
defp unquote(decode_name)(char) do
15501719
index = char - unquote(min)
15511720

0 commit comments

Comments
 (0)