EITD committed
Commit cb54541 · 1 Parent(s): a408e8f
Files changed (1):
  1. app.py +15 -7
app.py CHANGED
@@ -6,6 +6,15 @@ For more information on `huggingface_hub` Inference API support, please check th
 """
 # client = InferenceClient("EITD/lora_model", token=os.getenv("HF_TOKEN"))
 
+class CustomTextStreamer(TextStreamer):
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+        self.generated_text = ""
+
+    def on_token(self, token):
+        super().on_token(token)
+        self.generated_text += token
+
 model = AutoPeftModelForCausalLM.from_pretrained(
     "EITD/lora_model_1", # YOUR MODEL YOU USED FOR TRAINING
     load_in_4bit = False,
@@ -44,7 +53,7 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    # response = ""
+    response = ""
 
     # for message in client.chat_completion(
     #     messages,
@@ -65,14 +74,13 @@ def respond(
         return_tensors = "pt",
     )
 
-    # outputs = model.generate(input_ids = inputs, max_new_tokens = max_tokens, use_cache = True,
-    #                          temperature = temperature, min_p = top_p)
-
-    text_streamer = TextStreamer(tokenizer, skip_prompt = True)
-    model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = max_tokens,
+    custom_streamer = CustomTextStreamer(tokenizer)
+    model.generate(input_ids = inputs, streamer = custom_streamer, max_new_tokens = max_tokens,
                    use_cache = True, temperature = temperature, min_p = top_p)
 
-    # return tokenizer.batch_decode(outputs)
+    for token in custom_streamer.generated_text:
+        response += token
+        yield response
 
 
 """