cppowboy commited on
Commit
7f8e63d
·
verified ·
1 Parent(s): a9d3817

Update tokenization_minicpm.py

Browse files
Files changed (1) hide show
  1. tokenization_minicpm.py +10 -10
tokenization_minicpm.py CHANGED
@@ -4,7 +4,6 @@ import keyword
4
  import traceback
5
  import uuid
6
  from collections import deque
7
- from copy import deepcopy
8
  from logging import getLogger
9
  from typing import Any, Dict, List, Optional, Union
10
 
@@ -17,6 +16,7 @@ from jsonschema import Draft202012Validator, exceptions, validate
17
  from transformers import LlamaTokenizerFast
18
  from transformers.tokenization_utils_base import BatchEncoding
19
  from transformers.utils import TensorType
 
20
 
21
 
22
  logger = getLogger(__name__)
@@ -148,7 +148,7 @@ class MiniCPMTokenizer(LlamaTokenizerFast):
148
  tool_calls.append(this_one)
149
 
150
  return {
151
- "content": content.strip(),
152
  "tool_calls": [
153
  {"type": "function", "function": tool_call, "id": "call_" + uuid.uuid4().hex}
154
  for tool_call in tool_calls
@@ -158,13 +158,13 @@ class MiniCPMTokenizer(LlamaTokenizerFast):
158
  except:
159
  logger.error(traceback.format_exc())
160
  return {
161
- "content": content.strip(),
162
  "role": "assistant",
163
  "thought": thought_string,
164
  }
165
  else:
166
  return {
167
- "content": sequence.strip(),
168
  "role": "assistant",
169
  "thought": thought_string,
170
  }
@@ -259,10 +259,11 @@ def message_format(msg, system_suffix="", user_prefix=""):
259
  content = thought_prefix + content
260
  msg["content"] = content
261
  elif msg["role"] == "user":
262
- msg["content"] = user_prefix + "\n" + msg["content"]
 
263
  elif msg["role"] == "system":
264
  msg["content"] = msg["content"] + "\n" + system_suffix
265
- msg["content"] = msg["content"].strip()
266
  return msg
267
 
268
 
@@ -361,12 +362,12 @@ func2(params)
361
  <|tool_call_end|>
362
  {{answer the user's question directly or ask the user for more information}}
363
  """
364
- tools_string = tools_template.format(tools=tools_string).strip()
365
  else:
366
  tools_string = ""
367
 
368
  if add_to_system:
369
- if len(messages) > 0 and messages[0]["role"] != "system" and tools_string != "":
370
  messages.insert(0, {"role": "system", "content": ""})
371
  return [message_format(msg, system_suffix=tools_string, user_prefix="") for msg in messages]
372
  else:
@@ -429,5 +430,4 @@ def resolve_ast_by_type(value):
429
  output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
430
  else:
431
  raise Exception(f"Unsupported AST type: {type(value)}")
432
- return output
433
-
 
4
  import traceback
5
  import uuid
6
  from collections import deque
 
7
  from logging import getLogger
8
  from typing import Any, Dict, List, Optional, Union
9
 
 
16
  from transformers import LlamaTokenizerFast
17
  from transformers.tokenization_utils_base import BatchEncoding
18
  from transformers.utils import TensorType
19
+ from copy import deepcopy
20
 
21
 
22
  logger = getLogger(__name__)
 
148
  tool_calls.append(this_one)
149
 
150
  return {
151
+ "content": content,
152
  "tool_calls": [
153
  {"type": "function", "function": tool_call, "id": "call_" + uuid.uuid4().hex}
154
  for tool_call in tool_calls
 
158
  except:
159
  logger.error(traceback.format_exc())
160
  return {
161
+ "content": content,
162
  "role": "assistant",
163
  "thought": thought_string,
164
  }
165
  else:
166
  return {
167
+ "content": sequence,
168
  "role": "assistant",
169
  "thought": thought_string,
170
  }
 
259
  content = thought_prefix + content
260
  msg["content"] = content
261
  elif msg["role"] == "user":
262
+ if user_prefix != "":
263
+ msg["content"] = user_prefix + "\n" + msg["content"]
264
  elif msg["role"] == "system":
265
  msg["content"] = msg["content"] + "\n" + system_suffix
266
+ msg["content"] = msg["content"]
267
  return msg
268
 
269
 
 
362
  <|tool_call_end|>
363
  {{answer the user's question directly or ask the user for more information}}
364
  """
365
+ tools_string = tools_template.format(tools=tools_string)
366
  else:
367
  tools_string = ""
368
 
369
  if add_to_system:
370
+ if len(messages) > 0 and messages[0]["role"] != "system" and len(tools_string.strip()) > 0:
371
  messages.insert(0, {"role": "system", "content": ""})
372
  return [message_format(msg, system_suffix=tools_string, user_prefix="") for msg in messages]
373
  else:
 
430
  output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
431
  else:
432
  raise Exception(f"Unsupported AST type: {type(value)}")
433
+ return output