test
app.py CHANGED
@@ -6,15 +6,6 @@ For more information on `huggingface_hub` Inference API support, please check th
 """
 # client = InferenceClient("EITD/lora_model", token=os.getenv("HF_TOKEN"))
 
-class CustomTextStreamer(TextStreamer):
-    def __init__(self, tokenizer):
-        super().__init__(tokenizer)
-        self.generated_text = ""
-
-    def on_token(self, token):
-        super().on_token(token)
-        self.generated_text += token
-
 model = AutoPeftModelForCausalLM.from_pretrained(
     "EITD/lora_model_1", # YOUR MODEL YOU USED FOR TRAINING
     load_in_4bit = False,
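Side note on the removed class: `transformers.TextStreamer` has no `on_token` hook, so the override above was never invoked during generation, which may be why the class was dropped. A streamer that actually accumulates text would override `on_finalized_text`, the callback `TextStreamer` does call. A minimal sketch (the class name here is illustrative, not from this repo):

from transformers import TextStreamer

class AccumulatingStreamer(TextStreamer):
    """Sketch: collect the streamed text in addition to printing it."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.generated_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # TextStreamer calls this with each chunk of decoded text;
        # keep the default stdout printing, then accumulate.
        super().on_finalized_text(text, stream_end=stream_end)
        self.generated_text += text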
@@ -53,7 +44,7 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    response = ""
+    # response = ""
 
     # for message in client.chat_completion(
     #     messages,
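For context on why `response = ""` can be commented out here: it fed the commented-out `client.chat_completion` streaming loop from the stock Gradio ChatInterface template, which this Space's app.py started from. A hedged reconstruction of that template loop (parameter names follow the template, not necessarily this repo's exact code):

import os
from huggingface_hub import InferenceClient

client = InferenceClient("EITD/lora_model", token=os.getenv("HF_TOKEN"))

response = ""
for message in client.chat_completion(
    messages,
    max_tokens=max_tokens,
    stream=True,
    temperature=temperature,
    top_p=top_p,
):
    # Each streamed chunk carries a delta; accumulate and re-yield
    # the running text so the chat UI updates incrementally.
    token = message.choices[0].delta.content
    response += token
    yield response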
@@ -74,13 +65,14 @@ def respond(
         return_tensors = "pt",
     )
 
-
-
-
+    outputs = model.generate(input_ids = inputs, max_new_tokens = max_tokens, use_cache = True,
+                             temperature = temperature, min_p = top_p)
+
+    # text_streamer = TextStreamer(tokenizer, skip_prompt = True)
+    # model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = max_tokens,
+    #                use_cache = True, temperature = temperature, min_p = top_p)
 
-
-        response += token
-        yield response
+    yield tokenizer.batch_decode(outputs, skip_special_tokens = True)
 
 
 """
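One behavioral note on the new ending: `tokenizer.batch_decode(outputs, skip_special_tokens=True)` returns a list of strings that still includes the prompt, and it arrives in one shot rather than token by token. If streaming is wanted back without the custom class, `transformers.TextIteratorStreamer` is the usual route. A sketch assuming the same `model`, `tokenizer`, `inputs`, and slider values as above (it keeps `min_p = top_p` exactly as the diff does, where the UI's top_p slider is fed into min_p):

from threading import Thread
from transformers import TextIteratorStreamer

# Decodes only newly generated tokens (skip_prompt=True) and exposes
# them as an iterator the Gradio generator can drain.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs=dict(input_ids=inputs, streamer=streamer, max_new_tokens=max_tokens,
                use_cache=True, temperature=temperature, min_p=top_p),
).start()

response = ""
for token in streamer:
    response += token
    yield response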