nxphi47 commited on
Commit
69bfc1e
·
1 Parent(s): 1cb08ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -34
app.py CHANGED
@@ -28,7 +28,11 @@ from huggingface_hub import snapshot_download
28
  # @@ environments ================
29
 
30
  DEBUG = bool(int(os.environ.get("DEBUG", "1")))
31
- BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
 
 
 
 
32
  # for lang block, wether to block in history too
33
  LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
34
  TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
@@ -64,6 +68,9 @@ STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
64
  # how many iterations to perform safety check on response
65
  STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
66
 
 
 
 
67
  # self explanatory
68
  MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
69
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
@@ -88,7 +95,6 @@ MODEL_PATH=./seal-13b-chat-a
88
 
89
  """
90
 
91
-
92
  # ==============================
93
  print(f'DEBUG mode: {DEBUG}')
94
  print(f'Torch version: {torch.__version__}')
@@ -136,7 +142,6 @@ Your response should adapt to the norms and customs of the respective language a
136
  # ============ CONSTANT ============
137
  # https://github.com/gradio-app/gradio/issues/884
138
  MODEL_NAME = "SeaLLM-13B"
139
- MODEL_TITLE = "SeaLLM-13B - An Assistant for Southeast Asian Languages"
140
 
141
  MODEL_TITLE = """
142
  <div class="container" style="
@@ -212,22 +217,6 @@ path_markdown = """
212
 
213
 
214
 
215
-
216
- def _detect_lang(text):
217
- # Disable language that may have safety risk
218
- from langdetect import detect as detect_lang
219
- dlang = None
220
- try:
221
- dlang = detect_lang(text)
222
- except Exception as e:
223
- print(f'Error: {e}')
224
- if "No features in text." in str(e):
225
- return "en"
226
- else:
227
- return "zh"
228
- return dlang
229
-
230
-
231
  def custom_hf_model_weights_iterator(
232
  model_name_or_path: str,
233
  cache_dir: Optional[str] = None,
@@ -1003,18 +992,39 @@ Please also consider clearing the chat box for a better experience."""
1003
 
1004
  KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
1005
 
1006
- def block_zh(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1007
  message: str,
1008
  history: List[Tuple[str, str]] = None,
1009
  ) -> str:
1010
  # relieve history base block
 
 
 
1011
  if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
1012
  return True
1013
- elif 'zh' in _detect_lang(message):
1014
- print(f'Detect zh: {message}')
1015
- return True
1016
  else:
1017
- return False
 
 
 
 
 
1018
 
1019
 
1020
  def log_responses(history, message, response):
@@ -1029,13 +1039,9 @@ def safety_check(text, history=None, ) -> Optional[str]:
1029
  if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
1030
  return KEYWORD_BLOCK_MESSAGE
1031
 
1032
- if BLOCK_ZH:
1033
- if history is not None:
1034
- if block_zh(text, history):
1035
- return LANG_BLOCK_MESSAGE
1036
- else:
1037
- if "zh" in _detect_lang(text):
1038
- return LANG_BLOCK_MESSAGE
1039
 
1040
  return None
1041
 
@@ -1159,6 +1165,7 @@ def check_model_path(model_path) -> str:
1159
 
1160
  def maybe_delete_folder():
1161
  if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
 
1162
  print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
1163
  for filename in os.listdir(DELETE_FOLDER):
1164
  file_path = os.path.join(DELETE_FOLDER, filename)
@@ -1170,6 +1177,11 @@ def maybe_delete_folder():
1170
  except Exception as e:
1171
  print('Failed to delete %s. Reason: %s' % (file_path, e))
1172
 
 
 
 
 
 
1173
 
1174
  def launch():
1175
  global demo, llm, DEBUG
@@ -1187,8 +1199,10 @@ def launch():
1187
  ckpt_info = "None"
1188
 
1189
  print(
1190
- f'Launch config: {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
1191
  f'\n| model_title=`{model_title}` '
 
 
1192
  f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
1193
  f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
1194
  f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
@@ -1210,6 +1224,7 @@ def launch():
1210
  print(f'Creating in DEBUG MODE')
1211
  else:
1212
  # ! load the model
 
1213
 
1214
  if DOWNLOAD_SNAPSHOT:
1215
  print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
@@ -1258,7 +1273,8 @@ def launch():
1258
  latex_delimiters=[
1259
  { "left": "$", "right": "$", "display": False},
1260
  { "left": "$$", "right": "$$", "display": True},
1261
- ]
 
1262
  ),
1263
  textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
1264
  submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
@@ -1276,10 +1292,13 @@ def launch():
1276
  )
1277
  demo.title = MODEL_NAME
1278
  with demo:
1279
- # gr.Markdown(warning_markdown)
1280
  gr.Markdown(cite_markdown)
1281
  if DISPLAY_MODEL_PATH:
1282
  gr.Markdown(path_markdown.format(model_path=model_path))
 
 
 
 
1283
 
1284
  demo.queue()
1285
  demo.launch(server_port=PORT)
 
28
  # @@ environments ================
29
 
30
  DEBUG = bool(int(os.environ.get("DEBUG", "1")))
31
+
32
+ # List of languages to block
33
+ BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
34
+ BLOCK_LANGS = BLOCK_LANGS.strip().split(";") if len(BLOCK_LANGS.strip()) > 0 else []
35
+
36
  # for lang block, wether to block in history too
37
  LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
38
  TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
 
68
  # how many iterations to perform safety check on response
69
  STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
70
 
71
+ # whether to enable to popup accept user
72
+ ENABLE_AGREE_POPUP = bool(int(os.environ.get("ENABLE_AGREE_POPUP", "0")))
73
+
74
  # self explanatory
75
  MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
76
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
 
95
 
96
  """
97
 
 
98
  # ==============================
99
  print(f'DEBUG mode: {DEBUG}')
100
  print(f'Torch version: {torch.__version__}')
 
142
  # ============ CONSTANT ============
143
  # https://github.com/gradio-app/gradio/issues/884
144
  MODEL_NAME = "SeaLLM-13B"
 
145
 
146
  MODEL_TITLE = """
147
  <div class="container" style="
 
217
 
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def custom_hf_model_weights_iterator(
221
  model_name_or_path: str,
222
  cache_dir: Optional[str] = None,
 
992
 
993
  KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
994
 
995
+
996
+ def _detect_lang(text):
997
+ # Disable language that may have safety risk
998
+ from langdetect import detect as detect_lang
999
+ dlang = None
1000
+ try:
1001
+ dlang = detect_lang(text)
1002
+ except Exception as e:
1003
+ print(f'Error: {e}')
1004
+ if "No features in text." in str(e):
1005
+ return "en"
1006
+ else:
1007
+ return "zh"
1008
+ return dlang
1009
+
1010
+
1011
+ def block_lang(
1012
  message: str,
1013
  history: List[Tuple[str, str]] = None,
1014
  ) -> str:
1015
  # relieve history base block
1016
+ if len(BLOCK_LANGS) == 0:
1017
+ return False
1018
+
1019
  if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
1020
  return True
 
 
 
1021
  else:
1022
+ _lang = _detect_lang(message)
1023
+ if _lang in BLOCK_LANGS:
1024
+ print(f'Detect blocked {_lang}: {message}')
1025
+ return True
1026
+ else:
1027
+ return False
1028
 
1029
 
1030
  def log_responses(history, message, response):
 
1039
  if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
1040
  return KEYWORD_BLOCK_MESSAGE
1041
 
1042
+ if len(BLOCK_LANGS) > 0:
1043
+ if block_lang(text, history):
1044
+ return LANG_BLOCK_MESSAGE
 
 
 
 
1045
 
1046
  return None
1047
 
 
1165
 
1166
  def maybe_delete_folder():
1167
  if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
1168
+ import shutil
1169
  print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
1170
  for filename in os.listdir(DELETE_FOLDER):
1171
  file_path = os.path.join(DELETE_FOLDER, filename)
 
1177
  except Exception as e:
1178
  print('Failed to delete %s. Reason: %s' % (file_path, e))
1179
 
1180
+ AGREE_POP_SCRIPTS = """
1181
+ async () => {
1182
+ alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
1183
+ }
1184
+ """
1185
 
1186
  def launch():
1187
  global demo, llm, DEBUG
 
1199
  ckpt_info = "None"
1200
 
1201
  print(
1202
+ f'Launch config: {tensor_parallel=} / {dtype=} / {max_tokens} '
1203
  f'\n| model_title=`{model_title}` '
1204
+ f'\n| BLOCK_LANGS={BLOCK_LANGS} '
1205
+ f'\n| IS_DELETE_FOLDER={IS_DELETE_FOLDER} '
1206
  f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
1207
  f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
1208
  f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
 
1224
  print(f'Creating in DEBUG MODE')
1225
  else:
1226
  # ! load the model
1227
+ maybe_delete_folder()
1228
 
1229
  if DOWNLOAD_SNAPSHOT:
1230
  print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
 
1273
  latex_delimiters=[
1274
  { "left": "$", "right": "$", "display": False},
1275
  { "left": "$$", "right": "$$", "display": True},
1276
+ ],
1277
+ show_copy_button=True,
1278
  ),
1279
  textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
1280
  submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
 
1292
  )
1293
  demo.title = MODEL_NAME
1294
  with demo:
 
1295
  gr.Markdown(cite_markdown)
1296
  if DISPLAY_MODEL_PATH:
1297
  gr.Markdown(path_markdown.format(model_path=model_path))
1298
+
1299
+ if ENABLE_AGREE_POPUP:
1300
+ demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
1301
+
1302
 
1303
  demo.queue()
1304
  demo.launch(server_port=PORT)