HannahLin271 commited on
Commit
9c95e37
·
verified ·
1 Parent(s): 9782755

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +93 -0
utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/HannahLin271/nanoGPT_single_conversation/resolve/main/pytorch_model.bin
2
+ import os
3
+ import torch
4
+ from model import GPTConfig, GPT
5
+ from huggingface_hub import hf_hub_download
6
+ import shutil
7
+ import re
8
+ import sys
9
+ out_dir = "./out"
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ import requests
12
+ from pathlib import Path
13
+ from tqdm import tqdm
14
+ import gradio as gr
15
+
16
+ def download_file(url, output_path):
17
+ response = requests.get(url, stream=True)
18
+ response.raise_for_status()
19
+
20
+ total_size = int(response.headers.get("content-length", 0))
21
+ block_size = 1024
22
+
23
+ # Create a progress bar
24
+ progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
25
+
26
+ with open(output_path, "wb") as file:
27
+ for chunk in response.iter_content(chunk_size=block_size):
28
+ progress_bar.update(len(chunk))
29
+ file.write(chunk)
30
+
31
+ progress_bar.close()
32
+
33
+ if total_size != 0 and progress_bar.n != total_size:
34
+ print("Error: Downloaded file size does not match expected size")
35
+ else:
36
+ print(f"Download complete: {output_path}")
37
+
38
+ try:
39
+ # Send a GET request to the URL
40
+ response = requests.get(url, stream=True)
41
+ response.raise_for_status() # Check if the request was successful
42
+ if not os.path.exists(output_path):
43
+ print("downloading...")
44
+ output_path = Path(output_path)
45
+ output_path.parent.mkdir(parents=True, exist_ok=True)
46
+ with open(output_path, "wb") as file:
47
+ for chunk in response.iter_content(chunk_size=8192):
48
+ file.write(chunk)
49
+
50
+ print(f"File downloaded successfully and saved as {output_path}")
51
+ except requests.exceptions.RequestException as e:
52
+ print(f"An error occurred: {e}")
53
+
54
+ def init_model_from(url, filename):
55
+ # if file not exists, download
56
+ ckpt_path = Path(out_dir) / filename
57
+ ckpt_path.parent.mkdir(parents=True, exist_ok=True)
58
+ if not os.path.exists(ckpt_path):
59
+ gr.Info('Downloading model...')
60
+ download_file(url, ckpt_path)
61
+ gr.Info('✅Model downloaded successfully.', duration=2)
62
+ checkpoint = torch.load(ckpt_path, map_location=device)
63
+ gptconf = GPTConfig(**checkpoint['model_args'])
64
+ model = GPT(gptconf)
65
+ state_dict = checkpoint['model']
66
+ unwanted_prefix = '_orig_mod.'
67
+ for k,v in list(state_dict.items()):
68
+ if k.startswith(unwanted_prefix):
69
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
70
+ model.load_state_dict(state_dict)
71
+ return model
72
+
73
+ def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k): # generation function
74
+ x = (torch.tensor(encode(input), dtype=torch.long, device=device)[None, ...])
75
+ with torch.no_grad():
76
+ for k in range(samples):
77
+ generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
78
+
79
+ output = decode(generated[0].tolist())
80
+
81
+ match_botoutput = re.search(r'<human>(.*?)<', output)
82
+ match_emotion = re.search(r'<emotion>\s*(.*?)\s*<', output)
83
+ match_context = re.search(r'<context>\s*(.*?)\s*<', output)
84
+ response = ''
85
+ emotion = ''
86
+ context = ''
87
+ if match_botoutput:
88
+ try :
89
+ response = match_botoutput.group(1).replace('<endOfText>','')
90
+ except:
91
+ response = match_botoutput.group(1)
92
+ #return response, emotion, context
93
+ return [input, response]