Spaces:
Running
Running
File size: 4,877 Bytes
9db8ced 1e5090f 447c0ca 7764421 9db8ced 7764421 cd6894d 9187ced 9db8ced 9187ced 9db8ced 2606dde 9187ced 7764421 9187ced 9db8ced b5041c1 9db8ced 2606dde b7b2c8c 9187ced 9db8ced 9187ced 9db8ced 9187ced 9db8ced |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import { HF_ACCESS_TOKEN, HF_API_ROOT, MODELS, OLD_MODELS, TASK_MODEL } from "$env/static/private";
import type { ChatTemplateInput } from "$lib/types/Template";
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
/**
 * Makes the keys `K` of `T` optional while leaving every other property of `T` untouched.
 */
type Optional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
/**
 * Prompt layout used when a model config does not supply its own `chatPromptTemplate`:
 * the preprompt, then each message wrapped in the matching user/assistant start and end
 * tokens, then a trailing assistant start token.
 */
const defaultChatPromptTemplate = [
  "{{preprompt}}",
  "{{#each messages}}",
  "{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}",
  "{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}",
  "{{/each}}",
  "{{assistantMessageToken}}",
].join("");

/** A suggested prompt: a non-empty title plus the non-empty prompt text. */
const promptExampleSchema = z.object({
  title: z.string().min(1),
  prompt: z.string().min(1),
});

/** Generation parameters; keys not listed here are kept as-is thanks to `.passthrough()`. */
const generationParametersSchema = z
  .object({
    temperature: z.number().min(0).max(1),
    truncate: z.number().int().positive(),
    max_new_tokens: z.number().int().positive(),
    stop: z.array(z.string()).optional(),
    top_p: z.number().positive().optional(),
    top_k: z.number().positive().optional(),
    repetition_penalty: z.number().min(-2).max(2).optional(),
  })
  .passthrough();

/** Schema for one entry of the `MODELS` configuration. */
const modelConfig = z.object({
  /** Used as an identifier in DB */
  id: z.string().optional(),
  /** Used to link to the model page, and for inference */
  name: z.string().min(1),
  displayName: z.string().min(1).optional(),
  description: z.string().min(1).optional(),
  websiteUrl: z.string().url().optional(),
  modelUrl: z.string().url().optional(),
  datasetName: z.string().min(1).optional(),
  datasetUrl: z.string().url().optional(),
  // Message delimiter tokens all default to the empty string.
  userMessageToken: z.string().default(""),
  userMessageEndToken: z.string().default(""),
  assistantMessageToken: z.string().default(""),
  assistantMessageEndToken: z.string().default(""),
  messageEndToken: z.string().default(""),
  preprompt: z.string().default(""),
  prepromptUrl: z.string().url().optional(),
  chatPromptTemplate: z.string().default(defaultChatPromptTemplate),
  promptExamples: z.array(promptExampleSchema).optional(),
  endpoints: z.array(endpointSchema).optional(),
  parameters: generationParametersSchema.optional(),
});
// Decode the MODELS env value as JSON first, then validate it as an array of model configs.
const decodedModelsEnv: unknown = JSON.parse(MODELS);
const modelsRaw = z.array(modelConfig).parse(decodedModelsEnv);
/**
 * Downloads the preprompt text referenced by a model's `prepromptUrl`.
 *
 * @param url - Location of the preprompt text.
 * @returns The response body as text.
 * @throws Error when the server answers with a non-2xx status, so that an error page
 *   is never silently used as the model's preprompt.
 */
const fetchPreprompt = async (url: string): Promise<string> => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(
      `Failed to fetch preprompt from ${url}: ${response.status} ${response.statusText}`
    );
  }
  return response.text();
};

/**
 * Turns a validated model config into the runtime model description.
 *
 * - The user/assistant end tokens fall back to `messageEndToken` when empty.
 * - `id` and `displayName` fall back to `name`.
 * - `preprompt` is downloaded from `prepromptUrl` when one is configured.
 * - `parameters.stop` is additionally exposed as `parameters.stop_sequences`.
 *
 * @param m - A model config that already passed `modelConfig` validation.
 * @returns The config extended with the resolved fields and a `chatPromptRender` function.
 * @throws Error when the preprompt download fails or returns a non-2xx status.
 */
const processModel = async (m: z.infer<typeof modelConfig>) => ({
  ...m,
  userMessageEndToken: m.userMessageEndToken || m.messageEndToken,
  assistantMessageEndToken: m.assistantMessageEndToken || m.messageEndToken,
  // NOTE(review): the template is compiled against the raw config `m`, not against the
  // fallback values computed in this object — confirm that compileTemplate does not need
  // the resolved end tokens / preprompt.
  chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
  id: m.id || m.name,
  displayName: m.displayName || m.name,
  preprompt: m.prepromptUrl ? await fetchPreprompt(m.prepromptUrl) : m.preprompt,
  parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
});
/**
 * Attaches a `getEndpoint` method to a processed model.
 *
 * `getEndpoint` resolves to:
 * - a TGI endpoint at `${HF_API_ROOT}/${m.name}` authenticated with `HF_ACCESS_TOKEN`
 *   when the model has no endpoints configured (missing or empty list);
 * - otherwise one of the configured endpoints, picked at random in proportion to its
 *   `weight`.
 *
 * @param m - A model produced by `processModel`.
 * @returns The model extended with `getEndpoint`.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
  ...m,
  getEndpoint: async (): Promise<Endpoint> => {
    const configured = m.endpoints;
    // An empty list is treated like a missing one; previously it always ended in
    // "Failed to select endpoint" because the selection loop had nothing to iterate.
    if (!configured || configured.length === 0) {
      return endpointTgi({
        type: "tgi",
        url: `${HF_API_ROOT}/${m.name}`,
        accessToken: HF_ACCESS_TOKEN,
        weight: 1,
        model: m,
      });
    }

    const totalWeight = sum(configured.map((e) => e.weight));
    let remaining = Math.random() * totalWeight;

    // Walk the cumulative weights. `selected` always tracks the endpoint under
    // inspection, so if rounding of fractional weights lets the walk run past the end,
    // the last endpoint is used instead of failing.
    let selected: NonNullable<typeof m.endpoints>[number] | undefined;
    for (const endpoint of configured) {
      selected = endpoint;
      if (remaining < endpoint.weight) {
        break;
      }
      remaining -= endpoint.weight;
    }
    if (!selected) {
      // Unreachable for a non-empty list; kept as a defensive guard.
      throw new Error(`Failed to select endpoint`);
    }

    const args = { ...selected, model: m };
    if (args.type === "tgi") {
      return endpoints.tgi(args);
    } else if (args.type === "aws") {
      return endpoints.aws(args);
    } else if (args.type === "openai") {
      return endpoints.openai(args);
    } else if (args.type === "llamacpp") {
      return endpoints.llamacpp(args);
    } else {
      // for legacy reason
      return endpoints.tgi(args);
    }
  },
});
/** Every configured model, processed and given a `getEndpoint` method; resolved concurrently. */
export const models = await Promise.all(
  modelsRaw.map(async (rawModel) => addEndpoint(await processModel(rawModel)))
);
/** The first entry of `models`. */
export const defaultModel = models[0];
/** Shape of one entry of the `OLD_MODELS` configuration. */
const oldModelSchema = z.object({
  id: z.string().optional(),
  name: z.string().min(1),
  displayName: z.string().min(1).optional(),
});

/** Fills in `id` and `displayName` from `name` when they are missing or empty. */
const withOldModelDefaults = (m: z.infer<typeof oldModelSchema>) => ({
  ...m,
  id: m.id || m.name,
  displayName: m.displayName || m.name,
});

// Models that have been deprecated; an empty list when OLD_MODELS is not set
export const oldModels = OLD_MODELS
  ? z.array(oldModelSchema).parse(JSON.parse(OLD_MODELS)).map(withOldModelDefaults)
  : [];
/**
 * Builds a zod enum that accepts exactly the ids of the given models.
 *
 * @param _models - Models whose ids are allowed; the first element's `id` is read
 *   unconditionally, so the list must not be empty.
 * @returns A `z.enum` over the model ids.
 */
export const validateModel = (_models: BackendModel[]) => {
  // z.enum wants a non-empty tuple, hence the explicit head/tail split.
  const [head, ...tail] = _models;
  return z.enum([head.id, ...tail.map((model) => model.id)]);
};
// If `TASK_MODEL` is the name of a configured model we use that model; otherwise
// `TASK_MODEL` is parsed as a model config of its own. Without `TASK_MODEL` the default
// model is used.
// NOTE(review): when `TASK_MODEL` is neither a known model name nor valid JSON,
// `JSON.parse` throws, so the trailing `defaultModel` fallback is presumably never reached
// in that case — confirm whether a fallback was intended.
const namedTaskModel = TASK_MODEL
  ? models.find((candidate) => candidate.name === TASK_MODEL)
  : undefined;
const inlineTaskModel =
  TASK_MODEL && !namedTaskModel
    ? addEndpoint(await processModel(modelConfig.parse(JSON.parse(TASK_MODEL))))
    : undefined;
export const smallModel = namedTaskModel ?? inlineTaskModel ?? defaultModel;
/**
 * Type of a processed model (same shape as `defaultModel`) in which `preprompt` and
 * `parameters` may be omitted.
 */
export type BackendModel = Optional<typeof defaultModel, "preprompt" | "parameters">;
|