Spaces:
Runtime error
Runtime error
Peter
commited on
Commit
·
74b8229
1
Parent(s):
2e158ce
:tada: init from template
Browse files- .gitignore +12 -0
- app.py +251 -0
- converse.py +244 -0
- grammar_improve.py +463 -0
- requirements.txt +16 -0
- symspell_rsc/frequency_bigramdictionary_en_243_342.txt +0 -0
- symspell_rsc/frequency_dictionary_en_82_765.txt +0 -0
- utils.py +385 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# basics
|
2 |
+
*__pycache__*
|
3 |
+
|
4 |
+
# local testing
|
5 |
+
*aitextgen*
|
6 |
+
*scratch*
|
7 |
+
*tmp*
|
8 |
+
|
9 |
+
# gradio database files
|
10 |
+
*gradio_db_files*
|
11 |
+
*gradio*
|
12 |
+
*flagged*
|
app.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
app.py - the main file for the app. This creates the flask app and handles the routes.
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from transformers import pipeline
|
8 |
+
from cleantext import clean
|
9 |
+
from pathlib import Path
|
10 |
+
import warnings
|
11 |
+
import time
|
12 |
+
import argparse
|
13 |
+
import logging
|
14 |
+
import gradio as gr
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
from os.path import dirname
|
18 |
+
import nltk
|
19 |
+
from converse import discussion
|
20 |
+
from grammar_improve import (
|
21 |
+
detect_propers,
|
22 |
+
load_ns_checker,
|
23 |
+
neuspell_correct,
|
24 |
+
remove_repeated_words,
|
25 |
+
remove_trailing_punctuation,
|
26 |
+
build_symspell_obj,
|
27 |
+
symspeller,
|
28 |
+
fix_punct_spacing,
|
29 |
+
)
|
30 |
+
|
31 |
+
from utils import (
|
32 |
+
cleantxt_wrap,
|
33 |
+
corr,
|
34 |
+
)
|
35 |
+
|
36 |
+
nltk.download("stopwords") # TODO: find where this requirement originates from
|
37 |
+
|
38 |
+
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
|
39 |
+
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
|
40 |
+
import transformers
|
41 |
+
|
42 |
+
transformers.logging.set_verbosity_error()
|
43 |
+
logging.basicConfig()
|
44 |
+
cwd = Path.cwd()
|
45 |
+
my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
|
46 |
+
|
47 |
+
|
48 |
+
def chat(trivia_query):
|
49 |
+
"""
|
50 |
+
chat - helper function that makes the whole gradio thing work.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
trivia_query (str): the question to ask the bot
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
[str]: the bot's response
|
57 |
+
"""
|
58 |
+
history = []
|
59 |
+
response = ask_gpt(message=trivia_query, chat_pipe=my_chatbot)
|
60 |
+
history = [trivia_query, response]
|
61 |
+
html = ""
|
62 |
+
for item in history:
|
63 |
+
html += f"<b>{item}</b> <br>"
|
64 |
+
|
65 |
+
html += ""
|
66 |
+
|
67 |
+
return html
|
68 |
+
|
69 |
+
|
70 |
+
def ask_gpt(
|
71 |
+
message: str,
|
72 |
+
chat_pipe,
|
73 |
+
speaker="person alpha",
|
74 |
+
responder="person beta",
|
75 |
+
max_len=196,
|
76 |
+
top_p=0.95,
|
77 |
+
top_k=50,
|
78 |
+
temperature=0.6,
|
79 |
+
):
|
80 |
+
"""
|
81 |
+
|
82 |
+
ask_gpt - a function that takes in a prompt and generates a response using the pipeline. This interacts the discussion function.
|
83 |
+
|
84 |
+
Parameters:
|
85 |
+
message (str): the question to ask the bot
|
86 |
+
chat_pipe (str): the chat_pipe to use for the bot (default: "pszemraj/Ballpark-Trivia-XL")
|
87 |
+
speaker (str): the name of the speaker (default: "person alpha")
|
88 |
+
responder (str): the name of the responder (default: "person beta")
|
89 |
+
max_len (int): the maximum length of the response (default: 128)
|
90 |
+
top_p (float): the top probability threshold (default: 0.95)
|
91 |
+
top_k (int): the top k threshold (default: 50)
|
92 |
+
temperature (float): the temperature of the response (default: 0.7)
|
93 |
+
"""
|
94 |
+
|
95 |
+
st = time.perf_counter()
|
96 |
+
prompt = clean(message) # clean user input
|
97 |
+
prompt = prompt.strip() # get rid of any extra whitespace
|
98 |
+
in_len = len(prompt)
|
99 |
+
if in_len > 512:
|
100 |
+
prompt = prompt[-512:] # truncate to 512 chars
|
101 |
+
print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
|
102 |
+
max_len = min(max_len, 512)
|
103 |
+
|
104 |
+
resp = discussion(
|
105 |
+
prompt_text=prompt,
|
106 |
+
pipeline=chat_pipe,
|
107 |
+
speaker=speaker,
|
108 |
+
responder=responder,
|
109 |
+
top_p=top_p,
|
110 |
+
top_k=top_k,
|
111 |
+
temperature=temperature,
|
112 |
+
max_length=max_len,
|
113 |
+
)
|
114 |
+
gpt_et = time.perf_counter()
|
115 |
+
gpt_rt = round(gpt_et - st, 2)
|
116 |
+
rawtxt = resp["out_text"]
|
117 |
+
# check for proper nouns
|
118 |
+
if basic_sc and not detect_propers(rawtxt):
|
119 |
+
cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
|
120 |
+
elif not detect_propers(rawtxt):
|
121 |
+
cln_resp = neuspell_correct(rawtxt, checker=ns_checker)
|
122 |
+
else:
|
123 |
+
# no correction needed
|
124 |
+
cln_resp = rawtxt.strip()
|
125 |
+
bot_resp_a = corr(remove_repeated_words(cln_resp))
|
126 |
+
bot_resp = fix_punct_spacing(bot_resp_a)
|
127 |
+
print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
|
128 |
+
corr_rt = round(time.perf_counter() - gpt_et, 4)
|
129 |
+
print(
|
130 |
+
f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
|
131 |
+
)
|
132 |
+
return remove_trailing_punctuation(bot_resp)
|
133 |
+
|
134 |
+
|
135 |
+
def get_parser():
|
136 |
+
"""
|
137 |
+
get_parser - a helper function for the argparse module
|
138 |
+
"""
|
139 |
+
parser = argparse.ArgumentParser(
|
140 |
+
description="submit a question, GPT model responds"
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"-m",
|
144 |
+
"--model",
|
145 |
+
required=False,
|
146 |
+
type=str,
|
147 |
+
default="pszemraj/GPT-Converse-1pt3B-Neo-WoW-DD-17", # default model
|
148 |
+
help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--basic-sc",
|
152 |
+
required=False,
|
153 |
+
default=True, # TODO: change this back to False once Neuspell issues are resolved.
|
154 |
+
action="store_true",
|
155 |
+
help="turn on symspell (baseline) correction instead of the more advanced neural net models",
|
156 |
+
)
|
157 |
+
|
158 |
+
parser.add_argument(
|
159 |
+
"--verbose",
|
160 |
+
action="store_true",
|
161 |
+
default=False,
|
162 |
+
help="turn on verbose logging",
|
163 |
+
)
|
164 |
+
return parser
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
args = get_parser().parse_args()
|
169 |
+
default_model = str(args.model)
|
170 |
+
model_loc = Path(default_model) # if the model is a path, use it
|
171 |
+
basic_sc = args.basic_sc # whether to use the baseline spellchecker
|
172 |
+
device = 0 if torch.cuda.is_available() else -1
|
173 |
+
print(f"CUDA avail is {torch.cuda.is_available()}")
|
174 |
+
|
175 |
+
my_chatbot = (
|
176 |
+
pipeline("text-generation", model=model_loc.resolve(), device=device)
|
177 |
+
if model_loc.exists() and model_loc.is_dir()
|
178 |
+
else pipeline("text-generation", model=default_model, device=device)
|
179 |
+
) # if the model is a name, use it. stays on CPU if no GPU available
|
180 |
+
print(f"using model {my_chatbot.model}")
|
181 |
+
|
182 |
+
if basic_sc:
|
183 |
+
print("Using the baseline spellchecker")
|
184 |
+
schnellspell = build_symspell_obj()
|
185 |
+
else:
|
186 |
+
print("using Neuspell spell checker")
|
187 |
+
ns_checker = load_ns_checker(fast=False)
|
188 |
+
|
189 |
+
print(f"using model stored here: \n {model_loc} \n")
|
190 |
+
iface = gr.Interface(
|
191 |
+
chat,
|
192 |
+
inputs=["text"],
|
193 |
+
outputs="html",
|
194 |
+
examples_per_page=10,
|
195 |
+
examples=[
|
196 |
+
"How can you help me?",
|
197 |
+
"what can you do?",
|
198 |
+
"Hi, my name is……",
|
199 |
+
"Happy birthday!",
|
200 |
+
"I have a question, can you help me?",
|
201 |
+
"Do you know a joke?",
|
202 |
+
"Will you marry me?",
|
203 |
+
"Are you single?",
|
204 |
+
"Do you like people?",
|
205 |
+
"Are you part of the Matrix?",
|
206 |
+
"Do you have a hobby?",
|
207 |
+
"You’re clever",
|
208 |
+
"Tell me about your personality",
|
209 |
+
"You’re annoying",
|
210 |
+
"you suck",
|
211 |
+
"I want to speak to a human now.",
|
212 |
+
"Don’t you speak English?!",
|
213 |
+
"Are you human?",
|
214 |
+
"Are you a robot?",
|
215 |
+
"What is your name?",
|
216 |
+
"How old are you?",
|
217 |
+
"What’s your age?",
|
218 |
+
"What day is it today?",
|
219 |
+
"Who made you?",
|
220 |
+
"Which languages can you speak?",
|
221 |
+
"What is your mother’s name?",
|
222 |
+
"Where do you live?",
|
223 |
+
"What’s the weather like today?",
|
224 |
+
"Are you expensive?",
|
225 |
+
"Do you get smarter?",
|
226 |
+
"rate your overall satisfaction with the chatbot",
|
227 |
+
"How many icebergs are in the ocean?",
|
228 |
+
],
|
229 |
+
title=f"NLP template space: {default_model} Model",
|
230 |
+
description=f"this space is used as a template. please copy the files in the space to your own space repo, AND THEN edit them ",
|
231 |
+
article="here you can add more details about your model. \n\n"
|
232 |
+
"**Important Notes & About:**\n\n"
|
233 |
+
"1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
|
234 |
+
"2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n"
|
235 |
+
"3. Some params are still being tweaked (in the future, will be inputs) any feedback is welcome :)\n",
|
236 |
+
css="""
|
237 |
+
.chatbox {display:flex;flex-direction:column}
|
238 |
+
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
|
239 |
+
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
|
240 |
+
.resp_msg {background-color:lightgray;align-self:self-end}
|
241 |
+
""",
|
242 |
+
allow_screenshot=True,
|
243 |
+
allow_flagging="never",
|
244 |
+
theme="dark",
|
245 |
+
)
|
246 |
+
|
247 |
+
# launch the gradio interface and start the server
|
248 |
+
iface.launch(
|
249 |
+
# prevent_thread_lock=True,
|
250 |
+
enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
|
251 |
+
)
|
converse.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
converse.py - this script has functions for handling the conversation between the user and the bot.
|
3 |
+
|
4 |
+
https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
import pprint as pp
|
9 |
+
import time
|
10 |
+
import torch
|
11 |
+
import transformers
|
12 |
+
|
13 |
+
from grammar_improve import remove_trailing_punctuation
|
14 |
+
|
15 |
+
|
16 |
+
def discussion(
|
17 |
+
prompt_text: str,
|
18 |
+
speaker: str,
|
19 |
+
responder: str,
|
20 |
+
pipeline,
|
21 |
+
timeout=45,
|
22 |
+
max_length=128,
|
23 |
+
top_p=0.95,
|
24 |
+
top_k=50,
|
25 |
+
temperature=0.7,
|
26 |
+
full_text=False,
|
27 |
+
num_return_sequences=1,
|
28 |
+
device=-1,
|
29 |
+
verbose=False,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
|
33 |
+
|
34 |
+
Parameters
|
35 |
+
----------
|
36 |
+
prompt_text : str, the prompt to ask the bot, usually the user's question
|
37 |
+
speaker : str, the name of the person who is speaking the prompt
|
38 |
+
responder : str, the name of the person who is responding to the prompt
|
39 |
+
pipeline : transformers.Pipeline, the pipeline to use for generating the response
|
40 |
+
timeout : int, optional, the number of seconds to wait before timing out, by default 45
|
41 |
+
max_length : int, optional, the maximum number of tokens to generate, defaults to 128
|
42 |
+
top_p : float, optional, the top probability to use for sampling, defaults to 0.95
|
43 |
+
top_k : int, optional, the top k to use for sampling, defaults to 50
|
44 |
+
temperature : float, optional, the temperature to use for sampling, defaults to 0.7
|
45 |
+
full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
|
46 |
+
num_return_sequences : int, optional, the number of sequences to return, defaults to 1
|
47 |
+
device : int, optional, the device to use for generation, defaults to -1 (CPU)
|
48 |
+
verbose : bool, optional, whether to print the generated text, defaults to False
|
49 |
+
|
50 |
+
Returns
|
51 |
+
-------
|
52 |
+
str, the generated text
|
53 |
+
"""
|
54 |
+
|
55 |
+
p_list = [] # track conversation
|
56 |
+
p_list.append(speaker.lower() + ":" + "\n")
|
57 |
+
p_list.append(prompt_text.lower() + "\n")
|
58 |
+
p_list.append("\n")
|
59 |
+
p_list.append(responder.lower() + ":" + "\n")
|
60 |
+
this_prompt = "".join(p_list)
|
61 |
+
if verbose:
|
62 |
+
print("overall prompt:\n")
|
63 |
+
pp.pprint(this_prompt, indent=4)
|
64 |
+
# call the model
|
65 |
+
print("\n... generating...")
|
66 |
+
bot_dialogue = gen_response(
|
67 |
+
this_prompt,
|
68 |
+
pipeline,
|
69 |
+
speaker,
|
70 |
+
responder,
|
71 |
+
timeout=timeout,
|
72 |
+
max_length=max_length,
|
73 |
+
top_p=top_p,
|
74 |
+
top_k=top_k,
|
75 |
+
temperature=temperature,
|
76 |
+
full_text=full_text,
|
77 |
+
num_return_sequences=num_return_sequences,
|
78 |
+
device=device,
|
79 |
+
verbose=verbose,
|
80 |
+
)
|
81 |
+
if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
|
82 |
+
bot_resp = ", ".join(bot_dialogue)
|
83 |
+
elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
|
84 |
+
bot_resp = bot_dialogue[0]
|
85 |
+
else:
|
86 |
+
bot_resp = bot_dialogue
|
87 |
+
bot_resp = " ".join(bot_resp) if isinstance(bot_resp, list) else bot_resp
|
88 |
+
bot_resp = bot_resp.strip()
|
89 |
+
# remove the last ',' '.' chars
|
90 |
+
bot_resp = remove_trailing_punctuation(bot_resp)
|
91 |
+
if verbose:
|
92 |
+
print("\n... bot response:\n")
|
93 |
+
pp.pprint(bot_resp)
|
94 |
+
p_list.append(bot_resp + "\n")
|
95 |
+
p_list.append("\n")
|
96 |
+
|
97 |
+
print("\nfinished!")
|
98 |
+
# return the bot response and the full conversation
|
99 |
+
|
100 |
+
return {"out_text": bot_resp, "full_conv": p_list}
|
101 |
+
|
102 |
+
|
103 |
+
def gen_response(
|
104 |
+
query: str,
|
105 |
+
pipeline,
|
106 |
+
speaker: str,
|
107 |
+
responder: str,
|
108 |
+
timeout=45,
|
109 |
+
max_length=128,
|
110 |
+
top_p=0.95,
|
111 |
+
top_k=50,
|
112 |
+
temperature=0.7,
|
113 |
+
full_text=False,
|
114 |
+
num_return_sequences=1,
|
115 |
+
device=-1,
|
116 |
+
verbose=False,
|
117 |
+
**kwargs,
|
118 |
+
):
|
119 |
+
"""
|
120 |
+
gen_response - a function that takes in a prompt and generates a response using the pipeline. This operates underneath the discussion function.
|
121 |
+
|
122 |
+
Parameters
|
123 |
+
----------
|
124 |
+
query : str, the prompt to ask the bot, usually the user's question
|
125 |
+
speaker : str, the name of the person who is speaking the prompt
|
126 |
+
responder : str, the name of the person who is responding to the prompt
|
127 |
+
pipeline : transformers.Pipeline, the pipeline to use for generating the response
|
128 |
+
timeout : int, optional, the number of seconds to wait before timing out, by default 45
|
129 |
+
max_length : int, optional, the maximum number of tokens to generate, defaults to 128
|
130 |
+
top_p : float, optional, the top probability to use for sampling, defaults to 0.95
|
131 |
+
top_k : int, optional, the top k to use for sampling, defaults to 50
|
132 |
+
temperature : float, optional, the temperature to use for sampling, defaults to 0.7
|
133 |
+
full_text : bool, optional, whether to return the full text or just the generated text, defaults to False
|
134 |
+
num_return_sequences : int, optional, the number of sequences to return, defaults to 1
|
135 |
+
device : int, optional, the device to use for generation, defaults to -1 (CPU)
|
136 |
+
verbose : bool, optional, whether to print the generated text, defaults to False
|
137 |
+
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
str, the generated text
|
141 |
+
|
142 |
+
"""
|
143 |
+
|
144 |
+
if max_length > 1024:
|
145 |
+
max_length = 1024
|
146 |
+
print("max_length is too large, setting to 1024")
|
147 |
+
st = time.perf_counter()
|
148 |
+
|
149 |
+
response = pipeline(
|
150 |
+
query,
|
151 |
+
max_length=max_length,
|
152 |
+
temperature=temperature,
|
153 |
+
top_k=top_k,
|
154 |
+
top_p=top_p,
|
155 |
+
num_return_sequences=num_return_sequences,
|
156 |
+
max_time=timeout,
|
157 |
+
return_full_text=full_text,
|
158 |
+
no_repeat_ngram_size=3,
|
159 |
+
length_penalty=0.3,
|
160 |
+
repetition_penalty=3.4,
|
161 |
+
clean_up_tokenization_spaces=True,
|
162 |
+
**kwargs,
|
163 |
+
) # the likely better beam-less method
|
164 |
+
rt = round(time.perf_counter() - st, 2)
|
165 |
+
if verbose:
|
166 |
+
print(f"took {rt} sec to respond")
|
167 |
+
|
168 |
+
if verbose:
|
169 |
+
print("\n[DEBUG] generated:\n")
|
170 |
+
pp.pprint(response) # for debugging
|
171 |
+
# process the full result to get the ~bot response~ piece
|
172 |
+
this_result = str(response[0]["generated_text"]).split(
|
173 |
+
"\n"
|
174 |
+
) # TODO: adjust hardcoded value for index to dynamic (if n>1)
|
175 |
+
|
176 |
+
bot_dialogue = consolidate_texts(
|
177 |
+
name_resp=responder,
|
178 |
+
model_resp=this_result,
|
179 |
+
name_spk=speaker,
|
180 |
+
verbose=verbose,
|
181 |
+
print_debug=True,
|
182 |
+
)
|
183 |
+
if verbose:
|
184 |
+
print(f"DEBUG: {bot_dialogue} was original response pre-SC")
|
185 |
+
|
186 |
+
return bot_dialogue #
|
187 |
+
|
188 |
+
|
189 |
+
def consolidate_texts(
|
190 |
+
model_resp: list,
|
191 |
+
name_resp: str = None,
|
192 |
+
name_spk: str = None,
|
193 |
+
verbose=False,
|
194 |
+
print_debug=False,
|
195 |
+
):
|
196 |
+
"""
|
197 |
+
consolidate_texts - given a list with speaker name followed by speaker text, returns all consecutive values of the first speaker name
|
198 |
+
|
199 |
+
Parameters:
|
200 |
+
name_resp (str): the name of the person who is responding
|
201 |
+
model_resp (list): the list of strings to consolidate (usually from the model)
|
202 |
+
name_spk (str): the name of the person who is speaking
|
203 |
+
verbose (bool): whether to print the results
|
204 |
+
print_debug (bool): whether to print the debug info during looping
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
list, a list of all the consecutive messages of the first speaker name
|
208 |
+
"""
|
209 |
+
assert len(model_resp) > 0, "model_resp is empty"
|
210 |
+
if len(model_resp) == 1:
|
211 |
+
return model_resp[0]
|
212 |
+
name_resp = "person beta" if name_resp is None else name_resp
|
213 |
+
name_spk = "person alpha" if name_spk is None else name_spk
|
214 |
+
if verbose:
|
215 |
+
print("====" * 10)
|
216 |
+
print(f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}")
|
217 |
+
print(f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}")
|
218 |
+
fn_resp = []
|
219 |
+
|
220 |
+
name_counter = 0
|
221 |
+
break_safe = False
|
222 |
+
for resline in model_resp:
|
223 |
+
if name_resp.lower() in resline:
|
224 |
+
name_counter += 1
|
225 |
+
break_safe = True # know the line is from bot as this line starts with the name of the bot
|
226 |
+
continue # don't add this line to the list
|
227 |
+
if name_spk.lower() in resline.lower():
|
228 |
+
if print_debug:
|
229 |
+
print(f"\nDEBUG: \n\t{resline}\ncaused the break")
|
230 |
+
break # the name of the speaker is in the line, so we're done
|
231 |
+
if any([": " in resline,":\n" in resline]) and name_resp.lower() not in resline.lower():
|
232 |
+
if print_debug:
|
233 |
+
print(f"\nDEBUG: \n\t{resline}\ncaused the break")
|
234 |
+
break
|
235 |
+
else:
|
236 |
+
fn_resp.append(resline)
|
237 |
+
break_safe = False
|
238 |
+
if verbose:
|
239 |
+
print("--" * 10)
|
240 |
+
print("\nthe full response is:\n")
|
241 |
+
print("\n".join(fn_resp))
|
242 |
+
print("--" * 10)
|
243 |
+
|
244 |
+
return fn_resp
|
grammar_improve.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
grammar_improve.py - this .py script contains functions to improve the grammar of a user's input or the models output.
|
3 |
+
|
4 |
+
"""
|
5 |
+
|
6 |
+
from datetime import datetime
|
7 |
+
import os
|
8 |
+
import pprint as pp
|
9 |
+
from neuspell import BertChecker, SclstmChecker
|
10 |
+
import neuspell
|
11 |
+
import math
|
12 |
+
from cleantext import clean
|
13 |
+
import time
|
14 |
+
import re
|
15 |
+
import sys
|
16 |
+
from symspellpy.symspellpy import SymSpell
|
17 |
+
|
18 |
+
from utils import suppress_stdout
|
19 |
+
|
20 |
+
|
21 |
+
def detect_propers(text: str):
|
22 |
+
"""
|
23 |
+
detect_propers - detect if a string contains proper nouns
|
24 |
+
|
25 |
+
Args:
|
26 |
+
text (str): [string to be checked]
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
[bool]: [True if string contains proper nouns]
|
30 |
+
"""
|
31 |
+
pat = re.compile(r"(?:\w+['’])?\w+(?:-(?:\w+['’])?\w+)*")
|
32 |
+
return bool(pat.search(text))
|
33 |
+
|
34 |
+
|
35 |
+
def fix_punct_spaces(string):
|
36 |
+
"""
|
37 |
+
fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
string : str, required, input string to be corrected
|
42 |
+
|
43 |
+
Returns
|
44 |
+
-------
|
45 |
+
str, corrected string
|
46 |
+
"""
|
47 |
+
|
48 |
+
fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
|
49 |
+
string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string)
|
50 |
+
return string.strip()
|
51 |
+
|
52 |
+
|
53 |
+
def split_sentences(text: str):
|
54 |
+
"""
|
55 |
+
split_sentences - split a string into a list of sentences that keep their ending punctuation. powered by regex witchcraft
|
56 |
+
|
57 |
+
Args:
|
58 |
+
text (str): [string to be split]
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
[list]: [list of strings]
|
62 |
+
"""
|
63 |
+
return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
|
64 |
+
|
65 |
+
|
66 |
+
def remove_repeated_words(bot_response):
|
67 |
+
"""
|
68 |
+
remove_repeated_words - remove repeated words from a string, returning only the first instance of each word
|
69 |
+
|
70 |
+
Parameters
|
71 |
+
----------
|
72 |
+
bot_response : str
|
73 |
+
string to remove repeated words from
|
74 |
+
|
75 |
+
Returns
|
76 |
+
-------
|
77 |
+
str
|
78 |
+
string containing the first instance of each word
|
79 |
+
"""
|
80 |
+
words = bot_response.split()
|
81 |
+
unique_words = []
|
82 |
+
for word in words:
|
83 |
+
if word not in unique_words:
|
84 |
+
unique_words.append(word)
|
85 |
+
return " ".join(unique_words)
|
86 |
+
|
87 |
+
|
88 |
+
def remove_trailing_punctuation(text: str, fuLL_strip=False):
|
89 |
+
"""
|
90 |
+
remove_trailing_punctuation - remove trailing punctuation from a string. Purpose is to seem more natural to end users
|
91 |
+
|
92 |
+
Args:
|
93 |
+
text (str): [string to be cleaned]
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
[str]: [cleaned string]
|
97 |
+
"""
|
98 |
+
if fuLL_strip:
|
99 |
+
return text.strip("?!.,;:")
|
100 |
+
else:
|
101 |
+
return text.strip(".,;:")
|
102 |
+
|
103 |
+
|
104 |
+
def fix_punct_spacing(text: str):
|
105 |
+
fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
|
106 |
+
spc_text = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), text)
|
107 |
+
cln_text = re.sub(r"(\W)(?=\1)", "", spc_text)
|
108 |
+
|
109 |
+
return cln_text
|
110 |
+
|
111 |
+
|
112 |
+
"""
|
113 |
+
start of SymSpell code
|
114 |
+
"""
|
115 |
+
|
116 |
+
|
117 |
+
def symspeller(
|
118 |
+
my_string: str,
|
119 |
+
sym_checker=None,
|
120 |
+
max_dist: int = 2,
|
121 |
+
prefix_length: int = 7,
|
122 |
+
ignore_non_words=True,
|
123 |
+
dictionary_path: str = None,
|
124 |
+
bigram_path: str = None,
|
125 |
+
verbose=False,
|
126 |
+
):
|
127 |
+
"""
|
128 |
+
symspeller - a wrapper for the SymSpell class from symspellpy
|
129 |
+
|
130 |
+
Parameters
|
131 |
+
----------
|
132 |
+
my_string : str, required, default=None, the string to be checked
|
133 |
+
sym_checker : SymSpell, optional, default=None, the SymSpell object to use
|
134 |
+
max_dist : int, optional, default=3, the maximum distance to look for replacements
|
135 |
+
prefix_length : int, optional, default=7, the length of the prefixes to use
|
136 |
+
ignore_non_words : bool, optional, default=True, whether to ignore non-words
|
137 |
+
dictionary_path : str, optional, default=None, the path to the dictionary file
|
138 |
+
bigram_path : str, optional, default=None, the path to the bigram dictionary file
|
139 |
+
verbose : bool, optional, default=False, whether to print the results
|
140 |
+
|
141 |
+
Returns
|
142 |
+
-------
|
143 |
+
list,
|
144 |
+
|
145 |
+
"""
|
146 |
+
|
147 |
+
assert len(my_string) > 0, "entered string for correction is empty"
|
148 |
+
|
149 |
+
if sym_checker is None:
|
150 |
+
# need to create a new class object. user can specify their own dictionary and bigram files
|
151 |
+
if verbose:
|
152 |
+
print("creating new SymSpell object")
|
153 |
+
sym_checker = build_symspell_obj(
|
154 |
+
edit_dist=max_dist,
|
155 |
+
prefix_length=prefix_length,
|
156 |
+
dictionary_path=dictionary_path,
|
157 |
+
bigram_path=bigram_path,
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
if verbose:
|
161 |
+
print("using existing SymSpell object")
|
162 |
+
# max edit distance per lookup (per single word, not per whole input string)
|
163 |
+
suggestions = sym_checker.lookup_compound(
|
164 |
+
my_string,
|
165 |
+
max_edit_distance=max_dist,
|
166 |
+
ignore_non_words=ignore_non_words,
|
167 |
+
ignore_term_with_digits=True,
|
168 |
+
transfer_casing=True,
|
169 |
+
)
|
170 |
+
|
171 |
+
if verbose:
|
172 |
+
print(f"{len(suggestions)} suggestions found")
|
173 |
+
print(f"the original string is:\n\t{my_string}")
|
174 |
+
sug_list = [sug.term for sug in suggestions]
|
175 |
+
print(f"suggestions:\n\t{sug_list}\n")
|
176 |
+
|
177 |
+
if len(suggestions) < 1:
|
178 |
+
return clean(my_string) # no correction because no suggestions
|
179 |
+
else:
|
180 |
+
first_result = suggestions[0] # first result is the most likely
|
181 |
+
return first_result._term
|
182 |
+
|
183 |
+
|
184 |
+
def build_symspell_obj(
|
185 |
+
edit_dist=2,
|
186 |
+
prefix_length=7,
|
187 |
+
dictionary_path=None,
|
188 |
+
bigram_path=None,
|
189 |
+
):
|
190 |
+
"""
|
191 |
+
build_symspell_obj [build a SymSpell object]
|
192 |
+
|
193 |
+
Args:
|
194 |
+
verbose (bool, optional): Defaults to False.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
SymSpell: a SymSpell object
|
198 |
+
"""
|
199 |
+
dictionary_path = (
|
200 |
+
r"symspell_rsc/frequency_dictionary_en_82_765.txt"
|
201 |
+
if dictionary_path is None
|
202 |
+
else dictionary_path
|
203 |
+
)
|
204 |
+
bigram_path = (
|
205 |
+
r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"
|
206 |
+
if bigram_path is None
|
207 |
+
else bigram_path
|
208 |
+
)
|
209 |
+
sym_checker = SymSpell(
|
210 |
+
max_dictionary_edit_distance=edit_dist + 2, prefix_length=prefix_length
|
211 |
+
)
|
212 |
+
# term_index is the column of the term and count_index is the
|
213 |
+
# column of the term frequency
|
214 |
+
sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
|
215 |
+
sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
|
216 |
+
|
217 |
+
return sym_checker
|
218 |
+
|
219 |
+
|
220 |
+
"""
|
221 |
+
# if using t5b_correction to check for spelling errors, use this code to initialize the objects
|
222 |
+
|
223 |
+
import torch
|
224 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
225 |
+
|
226 |
+
model_name = 'deep-learning-analytics/GrammarCorrector'
|
227 |
+
# torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
228 |
+
torch_device = 'cpu'
|
229 |
+
gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
|
230 |
+
gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device)
|
231 |
+
|
232 |
+
"""
|
233 |
+
|
234 |
+
|
235 |
+
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
|
236 |
+
"""
|
237 |
+
t5b_correction - correct a string using a text2textgen pipeline model from transformers
|
238 |
+
|
239 |
+
Parameters
|
240 |
+
----------
|
241 |
+
prompt : str, required, input prompt to be corrected
|
242 |
+
korrektor : transformers.pipeline, required, pipeline object
|
243 |
+
verbose : bool, optional, whether to print the corrected prompt. Defaults to False.
|
244 |
+
beams : int, optional, number of beams to use for the correction. Defaults to 4.
|
245 |
+
|
246 |
+
Returns
|
247 |
+
-------
|
248 |
+
str, corrected prompt
|
249 |
+
"""
|
250 |
+
|
251 |
+
p_min_len = int(math.ceil(0.9 * len(prompt)))
|
252 |
+
p_max_len = int(math.ceil(1.1 * len(prompt)))
|
253 |
+
if verbose:
|
254 |
+
print(f"setting min to {p_min_len} and max to {p_max_len}\n")
|
255 |
+
gcorr_result = korrektor(
|
256 |
+
f"grammar: {prompt}",
|
257 |
+
return_text=True,
|
258 |
+
clean_up_tokenization_spaces=True,
|
259 |
+
num_beams=beams,
|
260 |
+
max_length=p_max_len,
|
261 |
+
repetition_penalty=1.3,
|
262 |
+
length_penalty=0.2,
|
263 |
+
no_repeat_ngram_size=2,
|
264 |
+
)
|
265 |
+
if verbose:
|
266 |
+
print(f"grammar correction result: \n\t{gcorr_result}\n")
|
267 |
+
return gcorr_result
|
268 |
+
|
269 |
+
|
270 |
+
def all_neuspell_chkrs():
|
271 |
+
"""
|
272 |
+
disp_neuspell_chkrs - display the neuspell checkers available
|
273 |
+
|
274 |
+
Parameters
|
275 |
+
----------
|
276 |
+
None
|
277 |
+
|
278 |
+
Returns
|
279 |
+
-------
|
280 |
+
checker_opts - list of checkers available
|
281 |
+
"""
|
282 |
+
|
283 |
+
checker_opts = dir(neuspell)
|
284 |
+
print(f"\navailable checkers:")
|
285 |
+
|
286 |
+
pp.pprint(checker_opts, indent=4, compact=True)
|
287 |
+
|
288 |
+
return checker_opts
|
289 |
+
|
290 |
+
|
291 |
+
def load_ns_checker(customckr=None, fast=False):
|
292 |
+
"""
|
293 |
+
load_ns_checker - helper function, load / "set up" a neuspell checker from huggingface transformers
|
294 |
+
|
295 |
+
Args:
|
296 |
+
customckr (neuspell.NeuSpell): [neuspell checker object], optional, if not provided, will load the default checker
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
[neuspell.NeuSpell]: [neuspell checker object]
|
300 |
+
"""
|
301 |
+
st = time.perf_counter()
|
302 |
+
# stop all printing to the console
|
303 |
+
with suppress_stdout():
|
304 |
+
if customckr is None and not fast:
|
305 |
+
|
306 |
+
checker = BertChecker(
|
307 |
+
pretrained=True
|
308 |
+
) # load the default checker, has the best balance
|
309 |
+
elif customckr is None and fast:
|
310 |
+
checker = SclstmChecker(
|
311 |
+
pretrained=True
|
312 |
+
) # this one is faster but not as accurate
|
313 |
+
else:
|
314 |
+
checker = customckr(pretrained=True)
|
315 |
+
rt_min = (time.perf_counter() - st) / 60
|
316 |
+
# return to standard logging level
|
317 |
+
print(f"\n\nloaded checker in {rt_min} minutes")
|
318 |
+
|
319 |
+
return checker
|
320 |
+
|
321 |
+
|
322 |
+
def neuspell_correct(input_text: str, checker=None, verbose=False):
|
323 |
+
"""
|
324 |
+
neuspell_correct - correct a string using neuspell.
|
325 |
+
note that modificaitons to the checker are needed if doing list-based corrections
|
326 |
+
|
327 |
+
Parameters
|
328 |
+
----------
|
329 |
+
input_text : str, required, input string to be corrected
|
330 |
+
checker : neuspell.NeuSpell, optional, neuspell checker object. Defaults to None.
|
331 |
+
verbose : bool, optional, whether to print the corrected string. Defaults to False.
|
332 |
+
|
333 |
+
Returns
|
334 |
+
-------
|
335 |
+
str, corrected string
|
336 |
+
"""
|
337 |
+
if isinstance(input_text, str) and len(input_text) < 4:
|
338 |
+
print(f"input text of {input_text} is too short to be corrected")
|
339 |
+
return input_text
|
340 |
+
|
341 |
+
if checker is None:
|
342 |
+
print("NOTE - no checker provided, loading default checker")
|
343 |
+
checker = SclstmChecker(pretrained=True)
|
344 |
+
|
345 |
+
corrected = checker.correct(input_text)
|
346 |
+
cleaned_txt = fix_punct_spaces(corrected)
|
347 |
+
|
348 |
+
if verbose:
|
349 |
+
print(f"neuspell correction result: \n\t{cleaned_txt}\n")
|
350 |
+
return cleaned_txt
|
351 |
+
|
352 |
+
|
353 |
+
def grammarpipe(corrector, qphrase: str):
|
354 |
+
"""
|
355 |
+
gramformer_correct - THE ORIGINAL ONE USED IN PROJECT AND NEEDS TO BE CHANGED.
|
356 |
+
Idea is to correct a string using a text2textgen pipeline model from transformers
|
357 |
+
Args:
|
358 |
+
corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
|
359 |
+
qphrase (str): [text to be corrected]
|
360 |
+
Returns:
|
361 |
+
[str]: [corrected text]
|
362 |
+
"""
|
363 |
+
if isinstance(qphrase, str) and len(qphrase) < 4:
|
364 |
+
print(f"input text of {qphrase} is too short to be corrected")
|
365 |
+
return qphrase
|
366 |
+
try:
|
367 |
+
corrected = corrector(
|
368 |
+
clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
|
369 |
+
)
|
370 |
+
return corrected[0]["generated_text"]
|
371 |
+
except Exception as e:
|
372 |
+
print(f"NOTE - failed to correct with grammarpipe:\n {e}")
|
373 |
+
return clean(qphrase)
|
374 |
+
|
375 |
+
|
376 |
+
def DLA_correct(qphrase: str):
|
377 |
+
"""
|
378 |
+
DLA_correct - an "overhead" function to call correct_grammar() on a string, allowing for each newline to be corrected individually
|
379 |
+
|
380 |
+
Args:
|
381 |
+
qphrase (str): [string to be corrected]
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
str, the list of the corrected strings joined under " "
|
385 |
+
"""
|
386 |
+
if isinstance(qphrase, str) and len(qphrase) < 4:
|
387 |
+
print(f"input text of {qphrase} is too short to be corrected")
|
388 |
+
return qphrase
|
389 |
+
|
390 |
+
sentences = split_sentences(qphrase)
|
391 |
+
if len(sentences) == 1:
|
392 |
+
corrected = correct_grammar(sentences[0])
|
393 |
+
return corrected
|
394 |
+
else:
|
395 |
+
full_cor = []
|
396 |
+
for sen in sentences:
|
397 |
+
corr_sen = correct_grammar(clean(sen))
|
398 |
+
full_cor.append(corr_sen)
|
399 |
+
return " ".join(full_cor)
|
400 |
+
|
401 |
+
|
402 |
+
def correct_grammar(
|
403 |
+
input_text: str,
|
404 |
+
tokenizer,
|
405 |
+
model,
|
406 |
+
n_results: int = 1,
|
407 |
+
beams: int = 8,
|
408 |
+
temp=1,
|
409 |
+
uniq_ngrams=2,
|
410 |
+
rep_penalty=1.5,
|
411 |
+
device="cpu",
|
412 |
+
):
|
413 |
+
"""
|
414 |
+
correct_grammar - correct a string using a text2textgen pipeline model from transformers.
|
415 |
+
This function is an alternative to the t5b_correction function.
|
416 |
+
|
417 |
+
Parameters
|
418 |
+
----------
|
419 |
+
input_text : str, required, input string to be corrected
|
420 |
+
tokenizer : transformers.T5Tokenizer, required, tokenizer object, already created w/ relevant model
|
421 |
+
model : transformers.T5ForConditionalGeneration, required, model object, already created w/ relevant model
|
422 |
+
n_results : int, optional, number of results to return. Defaults to 1.
|
423 |
+
beams : int, optional, number of beams to use for the correction. Defaults to 8.
|
424 |
+
temp : int, optional, temperature to use for the correction. Defaults to 1.
|
425 |
+
uniq_ngrams : int, optional, number of ngrams to use for the correction. Defaults to 2.
|
426 |
+
rep_penalty : float, optional, penalty to use for the correction. Defaults to 1.5.
|
427 |
+
device : str, optional, device to use for the correction. Defaults to 'cpu'.
|
428 |
+
|
429 |
+
Returns
|
430 |
+
-------
|
431 |
+
str, corrected string (or list of strings if n_results > 1)
|
432 |
+
"""
|
433 |
+
st = time.perf_counter()
|
434 |
+
|
435 |
+
if len(input_text) < 5:
|
436 |
+
return input_text
|
437 |
+
max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
|
438 |
+
batch = tokenizer(
|
439 |
+
[input_text],
|
440 |
+
truncation=True,
|
441 |
+
padding="max_length",
|
442 |
+
max_length=max_length,
|
443 |
+
return_tensors="pt",
|
444 |
+
).to(device)
|
445 |
+
translated = model.generate(
|
446 |
+
**batch,
|
447 |
+
max_length=max_length,
|
448 |
+
min_length=min(10, len(input_text)),
|
449 |
+
no_repeat_ngram_size=uniq_ngrams,
|
450 |
+
repetition_penalty=rep_penalty,
|
451 |
+
num_beams=beams,
|
452 |
+
num_return_sequences=n_results,
|
453 |
+
temperature=temp,
|
454 |
+
)
|
455 |
+
|
456 |
+
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
457 |
+
rt_min = (time.perf_counter() - st) / 60
|
458 |
+
print(f"\n\ncorrected in {rt_min} minutes")
|
459 |
+
|
460 |
+
if isinstance(tgt_text, list):
|
461 |
+
return tgt_text[0]
|
462 |
+
else:
|
463 |
+
return tgt_text
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers>=4.12.5
|
2 |
+
sentencepiece>=0.1.96
|
3 |
+
tqdm>=4.43.0
|
4 |
+
symspellpy>=6.7.0
|
5 |
+
requests>=2.24.0
|
6 |
+
gradio>=2.4.6
|
7 |
+
natsort>=7.1.1
|
8 |
+
pandas>=1.3.0
|
9 |
+
aitextgen>=0.5.2
|
10 |
+
clean-text>=0.5.0
|
11 |
+
openwa>=1.3.16
|
12 |
+
python-telegram-bot>=13.0
|
13 |
+
webwhatsapi>=2.0.5
|
14 |
+
Flask>=2.0.2
|
15 |
+
nltk>=3.6.6
|
16 |
+
neuspell>=1.0.0
|
symspell_rsc/frequency_bigramdictionary_en_243_342.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
symspell_rsc/frequency_dictionary_en_82_765.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
utils - general utility functions for loading, saving, and manipulating data
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
import pprint as pp
|
8 |
+
import re
|
9 |
+
import shutil # zipfile formats
|
10 |
+
from datetime import datetime
|
11 |
+
from os.path import basename
|
12 |
+
from os.path import getsize, join
|
13 |
+
|
14 |
+
import requests
|
15 |
+
from cleantext import clean
|
16 |
+
from natsort import natsorted
|
17 |
+
from symspellpy import SymSpell
|
18 |
+
import pandas as pd
|
19 |
+
from tqdm.auto import tqdm
|
20 |
+
|
21 |
+
|
22 |
+
from contextlib import contextmanager
|
23 |
+
import sys
|
24 |
+
import os
|
25 |
+
|
26 |
+
|
27 |
+
@contextmanager
|
28 |
+
def suppress_stdout():
|
29 |
+
"""
|
30 |
+
suppress_stdout - suppress stdout for a given block of code. credit to https://newbedev.com/how-to-suppress-console-output-in-python
|
31 |
+
"""
|
32 |
+
with open(os.devnull, "w") as devnull:
|
33 |
+
old_stdout = sys.stdout
|
34 |
+
sys.stdout = devnull
|
35 |
+
try:
|
36 |
+
yield
|
37 |
+
finally:
|
38 |
+
sys.stdout = old_stdout
|
39 |
+
|
40 |
+
|
41 |
+
def remove_string_extras(mytext):
|
42 |
+
# removes everything from a string except A-Za-z0-9 .,;
|
43 |
+
return re.sub(r"[^A-Za-z0-9 .,;]+", "", mytext)
|
44 |
+
|
45 |
+
|
46 |
+
def corr(s):
|
47 |
+
# adds space after period if there isn't one
|
48 |
+
# removes extra spaces
|
49 |
+
return re.sub(r"\.(?! )", ". ", re.sub(r" +", " ", s))
|
50 |
+
|
51 |
+
|
52 |
+
def get_timestamp():
|
53 |
+
# get timestamp for file names
|
54 |
+
return datetime.now().strftime("%b-%d-%Y_t-%H")
|
55 |
+
|
56 |
+
|
57 |
+
def print_spacer(n=1):
|
58 |
+
"""print_spacer - print a spacer line"""
|
59 |
+
print("\n -------- " * n)
|
60 |
+
|
61 |
+
|
62 |
+
def fast_scandir(dirname: str):
|
63 |
+
"""
|
64 |
+
fast_scandir [an os.path-based means to return all subfolders in a given filepath]
|
65 |
+
|
66 |
+
"""
|
67 |
+
|
68 |
+
subfolders = [f.path for f in os.scandir(dirname) if f.is_dir()]
|
69 |
+
for dirname in list(subfolders):
|
70 |
+
subfolders.extend(fast_scandir(dirname))
|
71 |
+
return subfolders # list
|
72 |
+
|
73 |
+
|
74 |
+
def create_folder(directory: str):
|
75 |
+
# you will never guess what this does
|
76 |
+
os.makedirs(directory, exist_ok=True)
|
77 |
+
|
78 |
+
|
79 |
+
def chunks(lst: list, n: int):
|
80 |
+
"""
|
81 |
+
chunks - Yield successive n-sized chunks from lst
|
82 |
+
Args: lst (list): list to be chunked
|
83 |
+
n (int): size of chunks
|
84 |
+
|
85 |
+
"""
|
86 |
+
|
87 |
+
for i in range(0, len(lst), n):
|
88 |
+
yield lst[i : i + n]
|
89 |
+
|
90 |
+
|
91 |
+
def chunky_pandas(my_df, num_chunks: int = 4):
|
92 |
+
"""
|
93 |
+
chunky_pandas [split dataframe into `num_chunks` equal chunks, return each inside a list]
|
94 |
+
|
95 |
+
Args:
|
96 |
+
my_df (pd.DataFrame)
|
97 |
+
num_chunks (int, optional): Defaults to 4.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
list: a list of dataframes
|
101 |
+
"""
|
102 |
+
n = int(len(my_df) // num_chunks)
|
103 |
+
list_df = [my_df[i : i + n] for i in range(0, my_df.shape[0], n)]
|
104 |
+
|
105 |
+
return list_df
|
106 |
+
|
107 |
+
|
108 |
+
def load_dir_files(
|
109 |
+
directory: str, req_extension=".txt", return_type="list", verbose=False
|
110 |
+
):
|
111 |
+
"""
|
112 |
+
load_dir_files - an os.path based method of returning all files with extension `req_extension` in a given directory and subdirectories
|
113 |
+
|
114 |
+
Args:
|
115 |
+
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
list or dict: an iterable of filepaths or a dict of filepaths and their respective filenames
|
119 |
+
"""
|
120 |
+
appr_files = []
|
121 |
+
# r=root, d=directories, f = files
|
122 |
+
for r, d, f in os.walk(directory):
|
123 |
+
for prefile in f:
|
124 |
+
if prefile.endswith(req_extension):
|
125 |
+
fullpath = os.path.join(r, prefile)
|
126 |
+
appr_files.append(fullpath)
|
127 |
+
|
128 |
+
appr_files = natsorted(appr_files)
|
129 |
+
|
130 |
+
if verbose:
|
131 |
+
print("A list of files in the {} directory are: \n".format(directory))
|
132 |
+
if len(appr_files) < 10:
|
133 |
+
pp.pprint(appr_files)
|
134 |
+
else:
|
135 |
+
pp.pprint(appr_files[:10])
|
136 |
+
print("\n and more. There are a total of {} files".format(len(appr_files)))
|
137 |
+
|
138 |
+
if return_type.lower() == "list":
|
139 |
+
return appr_files
|
140 |
+
else:
|
141 |
+
if verbose:
|
142 |
+
print("returning dictionary")
|
143 |
+
|
144 |
+
appr_file_dict = {}
|
145 |
+
for this_file in appr_files:
|
146 |
+
appr_file_dict[basename(this_file)] = this_file
|
147 |
+
|
148 |
+
return appr_file_dict
|
149 |
+
|
150 |
+
|
151 |
+
def URL_string_filter(text):
|
152 |
+
"""
|
153 |
+
URL_string_filter - filter out nonstandard "text" characters
|
154 |
+
|
155 |
+
"""
|
156 |
+
custom_printable = (
|
157 |
+
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ._"
|
158 |
+
)
|
159 |
+
|
160 |
+
filtered = "".join((filter(lambda i: i in custom_printable, text)))
|
161 |
+
|
162 |
+
return filtered
|
163 |
+
|
164 |
+
|
165 |
+
def getFilename_fromCd(cd):
|
166 |
+
"""getFilename_fromCd - get the filename from a given cd str"""
|
167 |
+
if not cd:
|
168 |
+
return None
|
169 |
+
fname = re.findall("filename=(.+)", cd)
|
170 |
+
if len(fname) > 0:
|
171 |
+
output = fname[0]
|
172 |
+
elif cd.find("/"):
|
173 |
+
possible_fname = cd.rsplit("/", 1)[1]
|
174 |
+
output = URL_string_filter(possible_fname)
|
175 |
+
else:
|
176 |
+
output = None
|
177 |
+
return output
|
178 |
+
|
179 |
+
|
180 |
+
def get_zip_URL(
|
181 |
+
URLtoget: str,
|
182 |
+
extract_loc: str = None,
|
183 |
+
file_header: str = "dropboxexport_",
|
184 |
+
verbose: bool = False,
|
185 |
+
):
|
186 |
+
"""get_zip_URL - download a zip file from a given URL and extract it to a given location"""
|
187 |
+
|
188 |
+
r = requests.get(URLtoget, allow_redirects=True)
|
189 |
+
names = getFilename_fromCd(r.headers.get("content-disposition"))
|
190 |
+
fixed_fnames = names.split(";") # split the multiple results
|
191 |
+
this_filename = file_header + URL_string_filter(fixed_fnames[0])
|
192 |
+
|
193 |
+
# define paths and save the zip file
|
194 |
+
if extract_loc is None:
|
195 |
+
extract_loc = "dropbox_dl"
|
196 |
+
dl_place = join(os.getcwd(), extract_loc)
|
197 |
+
create_folder(dl_place)
|
198 |
+
save_loc = join(os.getcwd(), this_filename)
|
199 |
+
open(save_loc, "wb").write(r.content)
|
200 |
+
if verbose:
|
201 |
+
print("downloaded file size was {} MB".format(getsize(save_loc) / 1000000))
|
202 |
+
|
203 |
+
# unpack the archive
|
204 |
+
shutil.unpack_archive(save_loc, extract_dir=dl_place)
|
205 |
+
if verbose:
|
206 |
+
print("extracted zip file - ", datetime.now())
|
207 |
+
x = load_dir_files(dl_place, req_extension="", verbose=verbose)
|
208 |
+
|
209 |
+
# remove original
|
210 |
+
try:
|
211 |
+
os.remove(save_loc)
|
212 |
+
del save_loc
|
213 |
+
except Exception:
|
214 |
+
print("unable to delete original zipfile - check if exists", datetime.now())
|
215 |
+
|
216 |
+
print("finished extracting zip - ", datetime.now())
|
217 |
+
|
218 |
+
return dl_place
|
219 |
+
|
220 |
+
|
221 |
+
def merge_dataframes(data_dir: str, ext=".xlsx", verbose=False):
|
222 |
+
"""
|
223 |
+
merge_dataframes - given a filepath, loads and attempts to merge all files as dataframes
|
224 |
+
|
225 |
+
Args:
|
226 |
+
data_dir (str): [root directory to search in]
|
227 |
+
ext (str, optional): [anticipate file extension for the dataframes ]. Defaults to '.xlsx'.
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
pd.DataFrame(): merged dataframe of all files
|
231 |
+
"""
|
232 |
+
|
233 |
+
src = Path(data_dir)
|
234 |
+
src_str = str(src.resolve())
|
235 |
+
mrg_df = pd.DataFrame()
|
236 |
+
|
237 |
+
all_reports = load_dir_files(directory=src_str, req_extension=ext, verbose=verbose)
|
238 |
+
|
239 |
+
failed = []
|
240 |
+
|
241 |
+
for df_path in tqdm(all_reports, total=len(all_reports), desc="joining data..."):
|
242 |
+
|
243 |
+
try:
|
244 |
+
this_df = pd.read_excel(df_path).convert_dtypes()
|
245 |
+
|
246 |
+
mrg_df = pd.concat([mrg_df, this_df], axis=0)
|
247 |
+
except Exception:
|
248 |
+
short_p = os.path.basename(df_path)
|
249 |
+
print(
|
250 |
+
f"WARNING - file with extension {ext} and name {short_p} could not be read."
|
251 |
+
)
|
252 |
+
failed.append(short_p)
|
253 |
+
|
254 |
+
if len(failed) > 0:
|
255 |
+
print("failed to merge {} files, investigate as needed")
|
256 |
+
|
257 |
+
if verbose:
|
258 |
+
pp.pprint(mrg_df.info(True))
|
259 |
+
|
260 |
+
return mrg_df
|
261 |
+
|
262 |
+
|
263 |
+
def download_URL(url: str, file=None, dlpath=None, verbose=False):
|
264 |
+
"""
|
265 |
+
download_URL - download a file from a URL and show progress bar
|
266 |
+
|
267 |
+
Parameters
|
268 |
+
----------
|
269 |
+
url : str
|
270 |
+
URL to download
|
271 |
+
file : [type], optional
|
272 |
+
[description], by default None
|
273 |
+
dlpath : [type], optional
|
274 |
+
[description], by default None
|
275 |
+
verbose : bool, optional
|
276 |
+
[description], by default False
|
277 |
+
|
278 |
+
Returns
|
279 |
+
-------
|
280 |
+
str - path to the downloaded file
|
281 |
+
"""
|
282 |
+
|
283 |
+
if file is None:
|
284 |
+
if "?dl=" in url:
|
285 |
+
# is a dropbox link
|
286 |
+
prefile = url.split("/")[-1]
|
287 |
+
filename = str(prefile).split("?dl=")[0]
|
288 |
+
else:
|
289 |
+
filename = url.split("/")[-1]
|
290 |
+
|
291 |
+
file = clean(filename)
|
292 |
+
if dlpath is None:
|
293 |
+
dlpath = Path.cwd() # save to current working directory
|
294 |
+
else:
|
295 |
+
dlpath = Path(dlpath) # make a path object
|
296 |
+
|
297 |
+
r = requests.get(url, stream=True, allow_redirects=True)
|
298 |
+
total_size = int(r.headers.get("content-length"))
|
299 |
+
initial_pos = 0
|
300 |
+
dl_loc = dlpath / file
|
301 |
+
with open(str(dl_loc.resolve()), "wb") as f:
|
302 |
+
with tqdm(
|
303 |
+
total=total_size,
|
304 |
+
unit="B",
|
305 |
+
unit_scale=True,
|
306 |
+
desc=file,
|
307 |
+
initial=initial_pos,
|
308 |
+
ascii=True,
|
309 |
+
) as pbar:
|
310 |
+
for ch in r.iter_content(chunk_size=1024):
|
311 |
+
if ch:
|
312 |
+
f.write(ch)
|
313 |
+
pbar.update(len(ch))
|
314 |
+
|
315 |
+
if verbose:
|
316 |
+
print(f"\ndownloaded {file} to {dlpath}\n")
|
317 |
+
|
318 |
+
return str(dl_loc.resolve())
|
319 |
+
|
320 |
+
|
321 |
+
def dl_extract_zip(
|
322 |
+
URLtoget: str,
|
323 |
+
extract_loc: str = None,
|
324 |
+
file_header: str = "TEMP_archive_dl_",
|
325 |
+
verbose: bool = False,
|
326 |
+
):
|
327 |
+
"""
|
328 |
+
dl_extract_zip - generic function to download a zip file and extract it
|
329 |
+
|
330 |
+
Parameters
|
331 |
+
----------
|
332 |
+
URLtoget : str
|
333 |
+
zip file URL to download
|
334 |
+
extract_loc : str, optional
|
335 |
+
directory to extract zip to , by default None
|
336 |
+
file_header : str, optional
|
337 |
+
[description], by default "TEMP_archive_dl_"
|
338 |
+
verbose : bool, optional
|
339 |
+
[description], by default False
|
340 |
+
|
341 |
+
Returns
|
342 |
+
-------
|
343 |
+
str - path to the downloaded and extracted folder
|
344 |
+
"""
|
345 |
+
|
346 |
+
extract_loc = Path(extract_loc)
|
347 |
+
extract_loc.mkdir(parents=True, exist_ok=True)
|
348 |
+
|
349 |
+
save_loc = download_URL(
|
350 |
+
url=URLtoget, file=f"{file_header}.zip", dlpath=None, verbose=verbose
|
351 |
+
)
|
352 |
+
|
353 |
+
shutil.unpack_archive(save_loc, extract_dir=extract_loc)
|
354 |
+
|
355 |
+
if verbose:
|
356 |
+
print("extracted zip file - ", datetime.now())
|
357 |
+
x = load_dir_files(extract_loc, req_extension="", verbose=verbose)
|
358 |
+
|
359 |
+
# remove original
|
360 |
+
try:
|
361 |
+
os.remove(save_loc)
|
362 |
+
del save_loc
|
363 |
+
except Exception:
|
364 |
+
print("unable to delete original zipfile - check if exists", datetime.now())
|
365 |
+
|
366 |
+
if verbose:
|
367 |
+
print("finished extracting zip - ", datetime.now())
|
368 |
+
|
369 |
+
return extract_loc
|
370 |
+
|
371 |
+
|
372 |
+
def cleantxt_wrap(ugly_text, all_lower=False):
|
373 |
+
"""
|
374 |
+
cleantxt_wrap - applies the clean function to a string.
|
375 |
+
|
376 |
+
Args:
|
377 |
+
ugly_text (str): [string to be cleaned]
|
378 |
+
|
379 |
+
Returns:
|
380 |
+
[str]: [cleaned string]
|
381 |
+
"""
|
382 |
+
if isinstance(ugly_text, str) and len(ugly_text) > 0:
|
383 |
+
return clean(ugly_text, lower=all_lower)
|
384 |
+
else:
|
385 |
+
return ugly_text
|