Skip to content

Commit b43272a

Browse files
authored
Unicode codepoint flags for custom regexs (ggml-org#7245)
* Replace CODEPOINT_TYPE_* with codepoint_flags * Update and bugfix brute force random test * Deterministic brute force random test * Unicode normalization NFD * Get rid of BOM
1 parent 0fc1e82 commit b43272a

7 files changed

+7297
-2407
lines changed

llama.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -12576,16 +12576,16 @@ struct llm_tokenizer_wpm {
1257612576
// to lowercase, pad chinese characters, pad punctuation
1257712577
std::string new_str = "";
1257812578
for (uint32_t code : cpts_nfd) {
12579-
int type = unicode_cpt_type(code);
12580-
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
12579+
const codepoint_flags flags = unicode_cpt_flags(code);
12580+
if (flags.is_accent_mark || flags.is_control) {
1258112581
continue;
1258212582
}
1258312583
code = unicode_tolower(code);
12584-
if (type == CODEPOINT_TYPE_SEPARATOR) {
12584+
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
1258512585
code = ' ';
1258612586
}
1258712587
std::string s = unicode_cpt_to_utf8(code);
12588-
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
12588+
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
1258912589
new_str += " ";
1259012590
new_str += s;
1259112591
new_str += " ";

scripts/gen-unicode-data.py

+116-46
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,134 @@
11
import regex
2+
import ctypes
3+
import unicodedata
24

35

4-
def get_matches(regex_expr):
5-
regex_expr_compiled = regex.compile(regex_expr)
6-
unicode_ranges = []
7-
current_range = None
6+
class CoodepointFlags (ctypes.Structure):
7+
_fields_ = [ # see definition in unicode.h
8+
("is_undefined", ctypes.c_uint16, 1),
9+
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
10+
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
11+
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
12+
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
13+
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
14+
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
15+
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
16+
]
817

9-
for codepoint in range(0x110000):
10-
char = chr(codepoint)
11-
if regex_expr_compiled.match(char):
12-
if current_range is None:
13-
current_range = [codepoint, codepoint]
14-
else:
15-
current_range[1] = codepoint
16-
elif current_range is not None:
17-
unicode_ranges.append(tuple(current_range))
18-
current_range = None
1918

20-
if current_range is not None:
21-
unicode_ranges.append(tuple(current_range))
19+
assert (ctypes.sizeof(CoodepointFlags) == 2)
2220

23-
return unicode_ranges
2421

22+
MAX_CODEPOINTS = 0x110000
2523

26-
def print_cat(mode, cat, ranges):
27-
if mode == "range":
28-
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
29-
if mode == "map":
30-
print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
31-
for i, values in enumerate(ranges):
32-
end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
33-
values = ["0x%08X" % value for value in values]
34-
print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
35-
print("};") # noqa: NP100
36-
print("") # noqa: NP100
24+
regex_number = regex.compile(r'\p{N}')
25+
regex_letter = regex.compile(r'\p{L}')
26+
regex_separator = regex.compile(r'\p{Z}')
27+
regex_accent_mark = regex.compile(r'\p{M}')
28+
regex_punctuation = regex.compile(r'\p{P}')
29+
regex_symbol = regex.compile(r'\p{S}')
30+
regex_control = regex.compile(r'\p{C}')
31+
regex_whitespace = regex.compile(r'\s')
3732

33+
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
34+
table_whitespace = []
35+
table_lowercase = []
36+
table_uppercase = []
37+
table_nfd = []
3838

39-
print_cat("range", "number", get_matches(r'\p{N}'))
40-
print_cat("range", "letter", get_matches(r'\p{L}'))
41-
print_cat("range", "separator", get_matches(r'\p{Z}'))
42-
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
43-
print_cat("range", "punctuation", get_matches(r'\p{P}'))
44-
print_cat("range", "symbol", get_matches(r'\p{S}'))
45-
print_cat("range", "control", get_matches(r'\p{C}'))
39+
for codepoint in range(MAX_CODEPOINTS):
40+
# convert codepoint to unicode character
41+
char = chr(codepoint)
4642

47-
print_cat("range", "whitespace", get_matches(r'\s'))
43+
# regex categories
44+
flags = codepoint_flags[codepoint]
45+
flags.is_number = bool(regex_number.match(char))
46+
flags.is_letter = bool(regex_letter.match(char))
47+
flags.is_separator = bool(regex_separator.match(char))
48+
flags.is_accent_mark = bool(regex_accent_mark.match(char))
49+
flags.is_punctuation = bool(regex_punctuation.match(char))
50+
flags.is_symbol = bool(regex_symbol.match(char))
51+
flags.is_control = bool(regex_control.match(char))
52+
flags.is_undefined = bytes(flags)[0] == 0
53+
assert (not flags.is_undefined)
4854

55+
# whitespaces
56+
if bool(regex_whitespace.match(char)):
57+
table_whitespace.append(codepoint)
4958

50-
map_lowercase = []
51-
map_uppercase = []
52-
for codepoint in range(0x110000):
53-
char = chr(codepoint)
59+
# lowercase conversion
5460
lower = ord(char.lower()[0])
55-
upper = ord(char.upper()[0])
5661
if codepoint != lower:
57-
map_lowercase.append((codepoint, lower))
62+
table_lowercase.append((codepoint, lower))
63+
64+
# uppercase conversion
65+
upper = ord(char.upper()[0])
5866
if codepoint != upper:
59-
map_uppercase.append((codepoint, upper))
60-
print_cat("map", "lowercase", map_lowercase)
61-
print_cat("map", "uppercase", map_uppercase)
67+
table_uppercase.append((codepoint, upper))
68+
69+
# NFD normalization
70+
norm = ord(unicodedata.normalize('NFD', char)[0])
71+
if codepoint != norm:
72+
table_nfd.append((codepoint, norm))
73+
74+
75+
# group ranges with same flags
76+
ranges_flags = [(0, codepoint_flags[0])] # start, flags
77+
for codepoint, flags in enumerate(codepoint_flags):
78+
if bytes(flags) != bytes(ranges_flags[-1][1]):
79+
ranges_flags.append((codepoint, flags))
80+
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
81+
82+
83+
# group ranges with same nfd
84+
ranges_nfd = [(0, 0, 0)] # start, last, nfd
85+
for codepoint, norm in table_nfd:
86+
start = ranges_nfd[-1][0]
87+
if ranges_nfd[-1] != (start, codepoint - 1, norm):
88+
ranges_nfd.append(None)
89+
start = codepoint
90+
ranges_nfd[-1] = (start, codepoint, norm)
91+
92+
93+
# Generate 'unicode-data.cpp'
94+
95+
96+
def out(line=""):
97+
print(line, end='\n') # noqa
98+
99+
100+
out("""\
101+
// generated with scripts/gen-unicode-data.py
102+
103+
#include "unicode-data.h"
104+
105+
#include <cstdint>
106+
#include <vector>
107+
#include <unordered_map>
108+
#include <unordered_set>
109+
""")
110+
111+
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
112+
for codepoint, flags in ranges_flags:
113+
flags = int.from_bytes(bytes(flags), "little")
114+
out("{0x%06X, 0x%04X}," % (codepoint, flags))
115+
out("};\n")
116+
117+
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
118+
out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
119+
out("};\n")
120+
121+
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
122+
for tuple in table_lowercase:
123+
out("{0x%06X, 0x%06X}," % tuple)
124+
out("};\n")
62125

126+
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
127+
for tuple in table_uppercase:
128+
out("{0x%06X, 0x%06X}," % tuple)
129+
out("};\n")
63130

64-
# TODO: generate unicode_map_nfd
131+
out("const std::vector<range_nfd> unicode_ranges_nfd = { // start, last, nfd")
132+
for triple in ranges_nfd:
133+
out("{0x%06X, 0x%06X, 0x%06X}," % triple)
134+
out("};\n")

0 commit comments

Comments
 (0)