muryshev commited on
Commit
5efafc9
·
1 Parent(s): efee073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -24
app.py CHANGED
@@ -2,9 +2,12 @@ from flask import Flask, request, Response
2
  import logging
3
  from llama_cpp import Llama
4
  import threading
5
- from huggingface_hub import snapshot_download
 
6
  import gc
7
  import os.path
 
 
8
 
9
  SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и отвечаешь на запросы пользователя, используя русский язык."
10
  SYSTEM_TOKEN = 1788
@@ -51,6 +54,29 @@ model = None
51
  model_path = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '/' + model_name
52
  app.logger.info('Model path: ' + model_path)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def init_model(context_size, enable_gpu=False, gpu_layer_number=35):
55
  global model
56
 
@@ -221,18 +247,8 @@ def generate_response():
221
  top_k = parameters.get("top_k", 30)
222
  return_full_text = parameters.get("return_full_text", False)
223
 
224
-
225
- # Generate the response
226
- #system_tokens = get_system_tokens(model)
227
- #tokens = system_tokens
228
-
229
- #if preprompt != "":
230
- # tokens = get_system_tokens_for_preprompt(model, preprompt)
231
- #else:
232
  tokens = get_system_tokens(model)
233
- tokens.append(LINEBREAK_TOKEN)
234
- #model.eval(tokens)
235
-
236
 
237
  tokens = []
238
 
@@ -243,22 +259,13 @@ def generate_response():
243
  message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
244
 
245
  tokens.extend(message_tokens)
246
-
247
- #app.logger.info('model.eval start')
248
- #model.eval(tokens)
249
- #app.logger.info('model.eval end')
250
-
251
- #last_message = messages[-1]
252
- #if last_message.get("from") == "assistant":
253
- # last_message_tokens = get_message_tokens(model=model, role="bot", content=last_message.get("content", ""))
254
- #else:
255
- # last_message_tokens = get_message_tokens(model=model, role="user", content=last_message.get("content", ""))
256
 
257
  tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
258
 
259
 
260
  app.logger.info('Prompt:')
261
- app.logger.info(model.detokenize(tokens[:CONTEXT_SIZE]).decode("utf-8", errors="ignore"))
 
262
 
263
  stop_generation = False
264
  app.logger.info('Generate started')
@@ -271,8 +278,20 @@ def generate_response():
271
  )
272
  app.logger.info('Generator created')
273
 
 
 
 
 
 
 
 
 
 
 
274
  # Use Response to stream tokens
275
- return Response(generate_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
 
 
276
 
277
  if __name__ == "__main__":
278
  app.run(host="0.0.0.0", port=7860, debug=False, threaded=False)
 
2
  import logging
3
  from llama_cpp import Llama
4
  import threading
5
+ from huggingface_hub import snapshot_download, Repository
6
+ import huggingface_hub
7
  import gc
8
  import os.path
9
+ import csv
10
+ from datetime import datetime
11
 
12
  SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и отвечаешь на запросы пользователя, используя русский язык."
13
  SYSTEM_TOKEN = 1788
 
54
  model_path = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '/' + model_name
55
  app.logger.info('Model path: ' + model_path)
56
 
57
+ DATASET_REPO_URL = "https://huggingface.co/datasets/muryshev/saiga-chat"
58
+ DATA_FILENAME = "data.csv"
59
+ DATA_FILE = os.path.join("data", DATA_FILENAME)
60
+
61
+ HF_TOKEN = os.environ.get("HF_TOKEN")
62
+ app.logger.info("HF_TOKEN is None?", HF_TOKEN is None)
63
+ app.logger.info("hfh", huggingface_hub.__version__)
64
+
65
+ repo = Repository(
66
+ local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
67
+ )
68
+
69
+ def log(request: str = '', response: str = ''):
70
+ if request or response:
71
+ with open(DATA_FILE, "a") as csvfile:
72
+ writer = csv.DictWriter(csvfile, fieldnames=["request", "response", "time"])
73
+ writer.writerow(
74
+ {"request": request, "response": response, "time": str(datetime.now())}
75
+ )
76
+ commit_url = repo.push_to_hub()
77
+ app.logger.info(commit_url)
78
+
79
+
80
  def init_model(context_size, enable_gpu=False, gpu_layer_number=35):
81
  global model
82
 
 
247
  top_k = parameters.get("top_k", 30)
248
  return_full_text = parameters.get("return_full_text", False)
249
 
 
 
 
 
 
 
 
 
250
  tokens = get_system_tokens(model)
251
+ tokens.append(LINEBREAK_TOKEN)
 
 
252
 
253
  tokens = []
254
 
 
259
  message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
260
 
261
  tokens.extend(message_tokens)
 
 
 
 
 
 
 
 
 
 
262
 
263
  tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
264
 
265
 
266
  app.logger.info('Prompt:')
267
+ request = model.detokenize(tokens[:CONTEXT_SIZE]).decode("utf-8", errors="ignore")
268
+ app.logger.info(request)
269
 
270
  stop_generation = False
271
  app.logger.info('Generate started')
 
278
  )
279
  app.logger.info('Generator created')
280
 
281
+
282
+ response_tokens = []
283
+ def generate_and_log_tokens(model, generator):
284
+ for token in generate_tokens(model, generator):
285
+ if token == model.token_eos(): # or (max_new_tokens is not None and i >= max_new_tokens):
286
+ log(request=request, response=model.detokenize(response_tokens).decode("utf-8", errors="ignore"))
287
+ break
288
+ response_tokens.append(token)
289
+ yield token
290
+
291
  # Use Response to stream tokens
292
+ return Response(generate_and_log_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
293
+
294
+
295
 
296
  if __name__ == "__main__":
297
  app.run(host="0.0.0.0", port=7860, debug=False, threaded=False)