Spaces:
Runtime error
Runtime error
🔊 add logs
Browse files
Signed-off-by: peter szemraj <[email protected]>
- constrained_generation.py +7 -1
- converse.py +1 -0
- utils.py +13 -0
constrained_generation.py
CHANGED
@@ -4,6 +4,7 @@
|
|
4 |
|
5 |
import copy
|
6 |
import logging
|
|
|
7 |
logging.basicConfig(level=logging.INFO)
|
8 |
import time
|
9 |
from pathlib import Path
|
@@ -11,6 +12,7 @@ from pathlib import Path
|
|
11 |
import yake
|
12 |
from transformers import AutoTokenizer, PhrasalConstraint
|
13 |
|
|
|
14 |
def get_tokenizer(model_name="gpt2", verbose=False):
|
15 |
"""
|
16 |
get_tokenizer - returns a tokenizer object
|
@@ -164,6 +166,8 @@ def constrained_generation(
|
|
164 |
-------
|
165 |
response : str, generated text
|
166 |
"""
|
|
|
|
|
167 |
st = time.perf_counter()
|
168 |
tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
|
169 |
tokenizer.add_prefix_space = True
|
@@ -228,7 +232,9 @@ def constrained_generation(
|
|
228 |
force_words_ids=force_words_ids if force_flexible is not None else None,
|
229 |
max_length=None,
|
230 |
max_new_tokens=max_generated_tokens,
|
231 |
-
min_length=min_generated_tokens + prompt_length
|
|
|
|
|
232 |
num_beams=num_beams,
|
233 |
no_repeat_ngram_size=no_repeat_ngram_size,
|
234 |
num_return_sequences=num_return_sequences,
|
|
|
4 |
|
5 |
import copy
|
6 |
import logging
|
7 |
+
|
8 |
logging.basicConfig(level=logging.INFO)
|
9 |
import time
|
10 |
from pathlib import Path
|
|
|
12 |
import yake
|
13 |
from transformers import AutoTokenizer, PhrasalConstraint
|
14 |
|
15 |
+
|
16 |
def get_tokenizer(model_name="gpt2", verbose=False):
|
17 |
"""
|
18 |
get_tokenizer - returns a tokenizer object
|
|
|
166 |
-------
|
167 |
response : str, generated text
|
168 |
"""
|
169 |
+
logging.debug(f" constraining generation with {locals()}")
|
170 |
+
|
171 |
st = time.perf_counter()
|
172 |
tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
|
173 |
tokenizer.add_prefix_space = True
|
|
|
232 |
force_words_ids=force_words_ids if force_flexible is not None else None,
|
233 |
max_length=None,
|
234 |
max_new_tokens=max_generated_tokens,
|
235 |
+
min_length=min_generated_tokens + prompt_length
|
236 |
+
if full_text
|
237 |
+
else min_generated_tokens,
|
238 |
num_beams=num_beams,
|
239 |
no_repeat_ngram_size=no_repeat_ngram_size,
|
240 |
num_return_sequences=num_return_sequences,
|
converse.py
CHANGED
@@ -186,6 +186,7 @@ def gen_response(
|
|
186 |
str, the generated text
|
187 |
|
188 |
"""
|
|
|
189 |
input_len = len(pipeline.tokenizer(query).input_ids)
|
190 |
if max_length + input_len > 1024:
|
191 |
max_length = max(1024 - input_len, 8)
|
|
|
186 |
str, the generated text
|
187 |
|
188 |
"""
|
189 |
+
logging.debug(f"input args - gen_response() : {locals()}")
|
190 |
input_len = len(pipeline.tokenizer(query).input_ids)
|
191 |
if max_length + input_len > 1024:
|
192 |
max_length = max(1024 - input_len, 8)
|
utils.py
CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
|
|
7 |
import pprint as pp
|
8 |
import re
|
9 |
import shutil # zipfile formats
|
|
|
10 |
from datetime import datetime
|
11 |
from os.path import basename
|
12 |
from os.path import getsize, join
|
@@ -383,3 +384,15 @@ def cleantxt_wrap(ugly_text, all_lower=False):
|
|
383 |
return clean(ugly_text, lower=all_lower)
|
384 |
else:
|
385 |
return ugly_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import pprint as pp
|
8 |
import re
|
9 |
import shutil # zipfile formats
|
10 |
+
import logging
|
11 |
from datetime import datetime
|
12 |
from os.path import basename
|
13 |
from os.path import getsize, join
|
|
|
384 |
return clean(ugly_text, lower=all_lower)
|
385 |
else:
|
386 |
return ugly_text
|
387 |
+
|
388 |
+
|
389 |
+
def setup_logging(loglevel):
|
390 |
+
"""Setup basic logging
|
391 |
+
|
392 |
+
Args:
|
393 |
+
loglevel (int): minimum loglevel for emitting messages
|
394 |
+
"""
|
395 |
+
logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
|
396 |
+
logging.basicConfig(
|
397 |
+
level=loglevel, stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S"
|
398 |
+
)
|