# https://huggingface.co/HannahLin271/nanoGPT_single_conversation/resolve/main/pytorch_model.bin
import os
import torch
from model import GPTConfig, GPT
from huggingface_hub import hf_hub_download
import shutil
import re
import sys
out_dir = "./out"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import requests
from pathlib import Path
from tqdm import tqdm
import gradio as gr

def download_file(url, output_path):
    """Stream a file from `url` to `output_path`, showing a progress bar."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail early on HTTP errors

        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192

        progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
        with open(output_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=block_size):
                progress_bar.update(len(chunk))
                file.write(chunk)
        progress_bar.close()

        # content-length may be missing (0); only verify the size when the
        # server reported one.
        if total_size != 0 and progress_bar.n != total_size:
            print("Error: downloaded file size does not match expected size")
        else:
            print(f"Download complete: {output_path}")
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")

def init_model_from(url, filename):
    """Download the checkpoint if it is missing, then build and load the model."""
    ckpt_path = Path(out_dir) / filename
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    if not os.path.exists(ckpt_path):
        gr.Info('Downloading model...', duration=10)
        download_file(url, ckpt_path)
        gr.Info('✅ Model downloaded successfully.', duration=2)
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Checkpoints saved from a torch.compile()d model prefix every key with
    # '_orig_mod.'; strip it so the keys match the plain module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.to(device)  # the prompt tensor built in respond() lives on `device`
    model.eval()      # inference only
    return model
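
# `respond` below expects `encode`/`decode` callables that this file never
# defines. The sketch below shows one way they might be built, assuming a
# nanoGPT-style character-level `meta.pkl` (with 'stoi'/'itos' mappings)
# stored alongside the checkpoint; the path and the helper's name are
# assumptions, not part of this module's contract.
import pickle

def load_char_tokenizer(meta_path=os.path.join(out_dir, "meta.pkl")):
    """Build encode/decode callables from a nanoGPT-style meta.pkl."""
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    encode = lambda s: [stoi[c] for c in s]             # string -> token ids
    decode = lambda ids: "".join(itos[i] for i in ids)  # token ids -> string
    return encode, decode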

def respond(prompt, samples, model, encode, decode, max_new_tokens, temperature, top_k):
    """Generate a reply for `prompt`; return [tagged prompt, response, raw output]."""
    prompt = "<bot> " + prompt
    x = torch.tensor(encode(prompt), dtype=torch.long, device=device)[None, ...]
    with torch.no_grad():
        # Only the first sample is used; `samples` is kept in the signature
        # for compatibility with callers.
        generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        output = decode(generated[0].tolist())

        # The reply is the text between the '<human>' tag and the next tag.
        match_botoutput = re.search(r'<human>(.*?)<', output, re.DOTALL)
        response = match_botoutput.group(1).strip() if match_botoutput else ''
        return [prompt, response, output]
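
# A minimal end-to-end sketch, guarded so it only runs when this file is
# executed directly. The checkpoint URL comes from the comment at the top of
# the file; the filename, prompt, and sampling parameters are illustrative
# assumptions. Note that init_model_from calls gr.Info, which is meant to run
# inside a Gradio event; outside one it may simply be ignored.
if __name__ == "__main__":
    url = ("https://huggingface.co/HannahLin271/nanoGPT_single_conversation"
           "/resolve/main/pytorch_model.bin")
    model = init_model_from(url, "pytorch_model.bin")
    encode, decode = load_char_tokenizer()  # tokenizer sketch defined above
    prompt, response, raw = respond(
        "Hello, how are you?", samples=1, model=model,
        encode=encode, decode=decode,
        max_new_tokens=100, temperature=0.8, top_k=200,
    )
    print(response)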