nsarrazin HF staff commited on
Commit
9187ced
·
unverified ·
1 Parent(s): f249cfc

Switch chat model back to mistral, use zephyr for small tasks (#515)

Browse files

* Switch chat model back to mistral, use zephyr for small tasks

* typo

* fix tests

Files changed (2) hide show
  1. .env.template +42 -10
  2. src/lib/server/models.ts +69 -67
.env.template CHANGED
@@ -94,24 +94,24 @@ MODELS=`[
94
  ]
95
  },
96
  {
97
- "name": "HuggingFaceH4/zephyr-7b-alpha",
98
- "displayName": "HuggingFaceH4/zephyr-7b-alpha",
99
- "description": "Zephyr 7B α is a fine-tune of Mistral 7B, released by the Hugging Face H4 RLHF team.",
100
- "websiteUrl": "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/",
101
  "preprompt": "",
102
- "chatPromptTemplate" : "<|system|>\n{{preprompt}}</s>\n{{#each messages}}{{#ifUser}}<|user|>\n{{content}}</s>\n<|assistant|>\n{{/ifUser}}{{#ifAssistant}}{{content}}</s>\n{{/ifAssistant}}{{/each}}",
103
  "parameters": {
104
- "temperature": 0.7,
105
  "top_p": 0.95,
106
  "repetition_penalty": 1.2,
107
  "top_k": 50,
108
  "truncate": 1000,
109
  "max_new_tokens": 2048,
110
- "stop": ["</s>", "<|>"]
111
  },
112
  "promptExamples": [
113
  {
114
- "title": "Write an email from bullet list",
115
  "prompt": "As a restaurant owner, write a professional email to the supplier to get these products every week: \n\n- Wine (x10)\n- Eggs (x24)\n- Bread (x12)"
116
  }, {
117
  "title": "Code a snake game",
@@ -124,8 +124,40 @@ MODELS=`[
124
  }
125
  ]`
126
 
127
- OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-llama-30b-xor"}, {"name":"mistralai/Mistral-7B-Instruct-v0.1"}]`
128
- TASK_MODEL='HuggingFaceH4/zephyr-7b-alpha'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  APP_BASE="/chat"
131
  PUBLIC_ORIGIN=https://huggingface.co
 
94
  ]
95
  },
96
  {
97
+ "name": "mistralai/Mistral-7B-Instruct-v0.1",
98
+ "displayName": "mistralai/Mistral-7B-Instruct-v0.1",
99
+ "description": "Mistral 7B is a new Apache 2.0 model, released by Mistral AI that outperforms Llama2 13B in benchmarks.",
100
+ "websiteUrl": "https://mistral.ai/news/announcing-mistral-7b/",
101
  "preprompt": "",
102
+ "chatPromptTemplate" : "<s>{{#each messages}}{{#ifUser}}[INST] {{#if @first}}{{#if @root.preprompt}}{{@root.preprompt}}\n{{/if}}{{/if}}{{content}} [/INST]{{/ifUser}}{{#ifAssistant}}{{content}}</s>{{/ifAssistant}}{{/each}}",
103
  "parameters": {
104
+ "temperature": 0.1,
105
  "top_p": 0.95,
106
  "repetition_penalty": 1.2,
107
  "top_k": 50,
108
  "truncate": 1000,
109
  "max_new_tokens": 2048,
110
+ "stop": ["</s>"]
111
  },
112
  "promptExamples": [
113
  {
114
+ "title": "Write an email from bullet list",
115
  "prompt": "As a restaurant owner, write a professional email to the supplier to get these products every week: \n\n- Wine (x10)\n- Eggs (x24)\n- Bread (x12)"
116
  }, {
117
  "title": "Code a snake game",
 
124
  }
125
  ]`
126
 
127
+ OLD_MODELS=`[{"name":"bigcode/starcoder"}, {"name":"OpenAssistant/oasst-sft-6-llama-30b-xor"}, {"name":"HuggingFaceH4/zephyr-7b-alpha"}]`
128
+
129
+ TASK_MODEL='
130
+ {
131
+ "name": "HuggingFaceH4/zephyr-7b-alpha",
132
+ "displayName": "HuggingFaceH4/zephyr-7b-alpha",
133
+ "description": "Zephyr 7B α is a fine-tune of Mistral 7B, released by the Hugging Face H4 RLHF team.",
134
+ "websiteUrl": "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/",
135
+ "preprompt": "",
136
+ "chatPromptTemplate" : "<|system|>\n{{preprompt}}</s>\n{{#each messages}}{{#ifUser}}<|user|>\n{{content}}</s>\n<|assistant|>\n{{/ifUser}}{{#ifAssistant}}{{content}}</s>\n{{/ifAssistant}}{{/each}}",
137
+ "parameters": {
138
+ "temperature": 0.7,
139
+ "top_p": 0.95,
140
+ "repetition_penalty": 1.2,
141
+ "top_k": 50,
142
+ "truncate": 1000,
143
+ "max_new_tokens": 2048,
144
+ "stop": ["</s>", "<|>"]
145
+ },
146
+ "promptExamples": [
147
+ {
148
+ "title": "Write an email from bullet list",
149
+ "prompt": "As a restaurant owner, write a professional email to the supplier to get these products every week: \n\n- Wine (x10)\n- Eggs (x24)\n- Bread (x12)"
150
+ }, {
151
+ "title": "Code a snake game",
152
+ "prompt": "Code a basic snake game in python, give explanations for each step."
153
+ }, {
154
+ "title": "Assist in a task",
155
+ "prompt": "How do I make a delicious lemon cheesecake?"
156
+ }
157
+ ]
158
+ }
159
+ '
160
+
161
 
162
  APP_BASE="/chat"
163
  PUBLIC_ORIGIN=https://huggingface.co
src/lib/server/models.ts CHANGED
@@ -37,70 +37,68 @@ const combinedEndpoint = endpoint.transform((data) => {
37
  }
38
  });
39
 
40
- const modelsRaw = z
41
- .array(
42
- z.object({
43
- /** Used as an identifier in DB */
44
- id: z.string().optional(),
45
- /** Used to link to the model page, and for inference */
46
- name: z.string().min(1),
47
- displayName: z.string().min(1).optional(),
48
- description: z.string().min(1).optional(),
49
- websiteUrl: z.string().url().optional(),
50
- modelUrl: z.string().url().optional(),
51
- datasetName: z.string().min(1).optional(),
52
- datasetUrl: z.string().url().optional(),
53
- userMessageToken: z.string().default(""),
54
- userMessageEndToken: z.string().default(""),
55
- assistantMessageToken: z.string().default(""),
56
- assistantMessageEndToken: z.string().default(""),
57
- messageEndToken: z.string().default(""),
58
- preprompt: z.string().default(""),
59
- prepromptUrl: z.string().url().optional(),
60
- chatPromptTemplate: z
61
- .string()
62
- .default(
63
- "{{preprompt}}" +
64
- "{{#each messages}}" +
65
- "{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
66
- "{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
67
- "{{/each}}" +
68
- "{{assistantMessageToken}}"
69
- ),
70
- promptExamples: z
71
- .array(
72
- z.object({
73
- title: z.string().min(1),
74
- prompt: z.string().min(1),
75
- })
76
- )
77
- .optional(),
78
- endpoints: z.array(combinedEndpoint).optional(),
79
- parameters: z
80
- .object({
81
- temperature: z.number().min(0).max(1),
82
- truncate: z.number().int().positive(),
83
- max_new_tokens: z.number().int().positive(),
84
- stop: z.array(z.string()).optional(),
85
- })
86
- .passthrough()
87
- .optional(),
88
  })
89
- )
90
- .parse(JSON.parse(MODELS));
 
91
 
92
- export const models = await Promise.all(
93
- modelsRaw.map(async (m) => ({
94
- ...m,
95
- userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
96
- assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
97
- chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
98
- id: m.id || m.name,
99
- displayName: m.displayName || m.name,
100
- preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
101
- parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
102
- }))
103
- );
 
 
104
 
105
  // Models that have been deprecated
106
  export const oldModels = OLD_MODELS
@@ -116,14 +114,18 @@ export const oldModels = OLD_MODELS
116
  .map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
117
  : [];
118
 
119
- export type BackendModel = Optional<(typeof models)[0], "preprompt" | "parameters">;
120
- export type Endpoint = z.infer<typeof endpoint>;
121
-
122
  export const defaultModel = models[0];
123
 
124
- export const smallModel = models.find((m) => m.name === TASK_MODEL) || defaultModel;
125
-
126
  export const validateModel = (_models: BackendModel[]) => {
127
  // Zod enum function requires 2 parameters
128
  return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
129
  };
 
 
 
 
 
 
 
 
 
 
37
  }
38
  });
39
 
40
+ const modelConfig = z.object({
41
+ /** Used as an identifier in DB */
42
+ id: z.string().optional(),
43
+ /** Used to link to the model page, and for inference */
44
+ name: z.string().min(1),
45
+ displayName: z.string().min(1).optional(),
46
+ description: z.string().min(1).optional(),
47
+ websiteUrl: z.string().url().optional(),
48
+ modelUrl: z.string().url().optional(),
49
+ datasetName: z.string().min(1).optional(),
50
+ datasetUrl: z.string().url().optional(),
51
+ userMessageToken: z.string().default(""),
52
+ userMessageEndToken: z.string().default(""),
53
+ assistantMessageToken: z.string().default(""),
54
+ assistantMessageEndToken: z.string().default(""),
55
+ messageEndToken: z.string().default(""),
56
+ preprompt: z.string().default(""),
57
+ prepromptUrl: z.string().url().optional(),
58
+ chatPromptTemplate: z
59
+ .string()
60
+ .default(
61
+ "{{preprompt}}" +
62
+ "{{#each messages}}" +
63
+ "{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
64
+ "{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
65
+ "{{/each}}" +
66
+ "{{assistantMessageToken}}"
67
+ ),
68
+ promptExamples: z
69
+ .array(
70
+ z.object({
71
+ title: z.string().min(1),
72
+ prompt: z.string().min(1),
73
+ })
74
+ )
75
+ .optional(),
76
+ endpoints: z.array(combinedEndpoint).optional(),
77
+ parameters: z
78
+ .object({
79
+ temperature: z.number().min(0).max(1),
80
+ truncate: z.number().int().positive(),
81
+ max_new_tokens: z.number().int().positive(),
82
+ stop: z.array(z.string()).optional(),
 
 
 
 
 
83
  })
84
+ .passthrough()
85
+ .optional(),
86
+ });
87
 
88
+ const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS));
89
+
90
+ const processModel = async (m: z.infer<typeof modelConfig>) => ({
91
+ ...m,
92
+ userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
93
+ assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
94
+ chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
95
+ id: m.id || m.name,
96
+ displayName: m.displayName || m.name,
97
+ preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
98
+ parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
99
+ });
100
+
101
+ export const models = await Promise.all(modelsRaw.map(processModel));
102
 
103
  // Models that have been deprecated
104
  export const oldModels = OLD_MODELS
 
114
  .map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
115
  : [];
116
 
 
 
 
117
  export const defaultModel = models[0];
118
 
 
 
119
  export const validateModel = (_models: BackendModel[]) => {
120
  // Zod enum function requires 2 parameters
121
  return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
122
  };
123
+
124
+ // if `TASK_MODEL` is the name of a model we use it, else we try to parse `TASK_MODEL` as a model config itself
125
+ export const smallModel = TASK_MODEL
126
+ ? models.find((m) => m.name === TASK_MODEL) ||
127
+ (await processModel(modelConfig.parse(JSON.parse(TASK_MODEL))))
128
+ : defaultModel;
129
+
130
+ export type BackendModel = Optional<(typeof models)[0], "preprompt" | "parameters">;
131
+ export type Endpoint = z.infer<typeof endpoint>;