Freiburg-AI-Research committed
Commit 541b616 · 1 Parent(s): 66fcbdb

Upload 6 files

glide_text2im/tokenizer/__init__.py ADDED
File without changes
glide_text2im/tokenizer/bpe.py ADDED
@@ -0,0 +1,151 @@
+ """
+ Byte pair encoding utilities adapted from:
+ https://github.com/openai/gpt-2/blob/master/src/encoder.py
+ """
+
+ import gzip
+ import json
+ import os
+ from functools import lru_cache
+ from typing import List, Tuple
+
+ import regex as re
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+     The reversible bpe codes work on unicode strings.
+     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     This is a significant percentage of your normal, say, 32K bpe vocab.
+     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     This also avoids mapping to whitespace/control characters the bpe code barfs on.
+     """
+     bs = (
+         list(range(ord("!"), ord("~") + 1))
+         + list(range(ord("¡"), ord("¬") + 1))
+         + list(range(ord("®"), ord("ÿ") + 1))
+     )
+     cs = bs[:]
+     n = 0
+     for b in range(2 ** 8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2 ** 8 + n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+
+
+ def get_pairs(word):
+     """Return set of symbol pairs in a word.
+     Word is represented as a tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+
+
+ class Encoder:
+     def __init__(self, encoder, bpe_merges, errors="replace"):
+         self.encoder = encoder
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.errors = errors  # how to handle errors in decoding
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+         self.cache = {}
+
+         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+         self.pat = re.compile(
+             r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+         )
+
+     @property
+     def n_vocab(self) -> int:
+         return len(self.encoder)
+
+     @property
+     def end_token(self) -> int:
+         return self.n_vocab - 1
+
+     def padded_tokens_and_mask(
+         self, tokens: List[int], text_ctx: int
+     ) -> Tuple[List[int], List[bool]]:
+         tokens = tokens[:text_ctx]
+         padding = text_ctx - len(tokens)
+         padded_tokens = tokens + [self.end_token] * padding
+         mask = [True] * len(tokens) + [False] * padding
+         return padded_tokens, mask
+
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token)
+         pairs = get_pairs(word)
+
+         if not pairs:
+             return token
+
+         while True:
+             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                     new_word.extend(word[i:j])
+                     i = j
+                 except:  # pylint: disable=bare-except
+                     new_word.extend(word[i:])
+                     break
+
+                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                     new_word.append(first + second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = " ".join(word)
+         self.cache[token] = word
+         return word
+
+     def encode(self, text):
+         text = text.lower()
+         bpe_tokens = []
+         for token in re.findall(self.pat, text):
+             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
+         return bpe_tokens
+
+     def decode(self, tokens):
+         text = "".join([self.decoder[token] for token in tokens])
+         text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+         return text
+
+
+ def get_encoder():
+     root_dir = os.path.dirname(os.path.abspath(__file__))
+     with gzip.open(os.path.join(root_dir, "encoder.json.gz"), "r") as f:
+         encoder = json.load(f)
+     with gzip.open(os.path.join(root_dir, "vocab.bpe.gz"), "r") as f:
+         bpe_data = str(f.read(), "utf-8")
+     bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
+     return Encoder(
+         encoder=encoder,
+         bpe_merges=bpe_merges,
+     )
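
A minimal usage sketch for the encoder above (not part of the uploaded files): it assumes the gzipped encoder.json.gz and vocab.bpe.gz added below sit next to bpe.py, and the text_ctx value of 128 is an illustrative placeholder rather than a value taken from this repository.

# Hypothetical usage of the GPT-2-style BPE tokenizer defined in bpe.py.
from glide_text2im.tokenizer.bpe import get_encoder

enc = get_encoder()                                    # loads encoder.json.gz and vocab.bpe.gz
tokens = enc.encode("a corgi wearing a red bowtie")    # list of BPE token ids (input is lower-cased)
padded, mask = enc.padded_tokens_and_mask(tokens, text_ctx=128)

assert len(padded) == 128 and len(mask) == 128         # padded with end_token; mask marks real tokens
print(enc.decode(tokens))                              # round-trips to the lower-cased prompt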
glide_text2im/tokenizer/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
glide_text2im/tokenizer/encoder.json.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4debc1cf25180021b07744bc9f4488d53c7bf112c8ce5de8097c6a7518f4ec7c
+ size 348346
glide_text2im/tokenizer/simple_tokenizer.py ADDED
@@ -0,0 +1,163 @@
+ """
+ Copied from: https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/simple_tokenizer.py
+ """
+
+ import gzip
+ import html
+ import os
+ from functools import lru_cache
+ from typing import List, Tuple
+
+ import ftfy
+ import regex as re
+
+
+ @lru_cache()
+ def default_bpe():
+     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+     The reversible bpe codes work on unicode strings.
+     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     This is a significant percentage of your normal, say, 32K bpe vocab.
+     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     This also avoids mapping to whitespace/control characters the bpe code barfs on.
+     """
+     bs = (
+         list(range(ord("!"), ord("~") + 1))
+         + list(range(ord("¡"), ord("¬") + 1))
+         + list(range(ord("®"), ord("ÿ") + 1))
+     )
+     cs = bs[:]
+     n = 0
+     for b in range(2 ** 8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2 ** 8 + n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+
+
+ def get_pairs(word):
+     """Return set of symbol pairs in a word.
+     Word is represented as a tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+
+
+ class SimpleTokenizer(object):
+     def __init__(self, bpe_path: str = default_bpe()):
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+         merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+         merges = merges[1 : 49152 - 256 - 2 + 1]
+         merges = [tuple(merge.split()) for merge in merges]
+         vocab = list(bytes_to_unicode().values())
+         vocab = vocab + [v + "</w>" for v in vocab]
+         for merge in merges:
+             vocab.append("".join(merge))
+         vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+         self.encoder = dict(zip(vocab, range(len(vocab))))
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.bpe_ranks = dict(zip(merges, range(len(merges))))
+         self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
+         self.pat = re.compile(
+             r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+             re.IGNORECASE,
+         )
+
+     @property
+     def start_token(self):
+         return self.encoder["<|startoftext|>"]
+
+     @property
+     def end_token(self):
+         return self.encoder["<|endoftext|>"]
+
+     def padded_tokens_and_len(self, tokens: List[int], text_ctx: int) -> Tuple[List[int], int]:
+         tokens = [self.start_token] + tokens[: text_ctx - 2] + [self.end_token]
+         text_len = len(tokens)
+         padding = text_ctx - len(tokens)
+         padded_tokens = tokens + [0] * padding
+         return padded_tokens, text_len
+
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token[:-1]) + (token[-1] + "</w>",)
+         pairs = get_pairs(word)
+
+         if not pairs:
+             return token + "</w>"
+
+         while True:
+             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                     new_word.extend(word[i:j])
+                     i = j
+                 except:  # pylint: disable=bare-except
+                     new_word.extend(word[i:])
+                     break
+
+                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                     new_word.append(first + second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = " ".join(word)
+         self.cache[token] = word
+         return word
+
+     def encode(self, text):
+         bpe_tokens = []
+         text = whitespace_clean(basic_clean(text)).lower()
+         for token in re.findall(self.pat, text):
+             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
+         return bpe_tokens
+
+     def decode(self, tokens):
+         text = "".join([self.decoder[token] for token in tokens])
+         text = (
+             bytearray([self.byte_decoder[c] for c in text])
+             .decode("utf-8", errors="replace")
+             .replace("</w>", " ")
+         )
+         return text
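
A comparable sketch for the CLIP-style SimpleTokenizer (again not part of the uploaded files): it assumes bpe_simple_vocab_16e6.txt.gz sits next to simple_tokenizer.py, and text_ctx=77 is an assumed context length used only for illustration.

# Hypothetical usage of SimpleTokenizer; text_ctx=77 is an assumption.
from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()                                  # loads bpe_simple_vocab_16e6.txt.gz
tokens = tok.encode("A corgi wearing a red bowtie")      # ftfy-fixed, whitespace-cleaned, lower-cased
padded, text_len = tok.padded_tokens_and_len(tokens, text_ctx=77)

# padded is [<|startoftext|>] + (possibly truncated) tokens + [<|endoftext|>], then zero padding;
# text_len counts the non-padding positions.
print(tok.decode(padded[1:text_len - 1]))                # approximately recovers the cleaned, lower-cased prompt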
glide_text2im/tokenizer/vocab.bpe.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce239dd5a898827423fee00e3f7ab37de7900f247f2ba360753d860e8a46524d
+ size 213544