import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import { buildPrompt } from "$lib/buildPrompt";
import { OPENAI_API_KEY } from "$env/static/private";
import type { Endpoint } from "../endpoints";
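
// Zod schema for an OpenAI-compatible endpoint configuration: base URL, API key,
// which completion API to use, and optional extra headers / query parameters.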
export const endpointOAIParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("openai"),
	baseURL: z.string().url().default("https://api.openai.com/v1"),
	apiKey: z.string().default(OPENAI_API_KEY ?? "sk-"),
	completion: z
		.union([z.literal("completions"), z.literal("chat_completions")])
		.default("chat_completions"),
	defaultHeaders: z.record(z.string()).optional(),
	defaultQuery: z.record(z.string()).optional(),
});
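
// Factory that validates the endpoint config and returns an Endpoint function
// backed by either the legacy completions API or the chat completions API.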
export async function endpointOai(
	input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
		endpointOAIParametersSchema.parse(input);
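
	// Load the `openai` SDK dynamically (it may not be installed) and fail with a
	// descriptive error if the import does not succeed.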
	let OpenAI;
	try {
		OpenAI = (await import("openai")).OpenAI;
	} catch (e) {
		throw new Error("Failed to import OpenAI", { cause: e });
	}
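
	// Instantiate the client against the configured base URL, forwarding any extra
	// default headers and query parameters to every request.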
	const openai = new OpenAI({
		apiKey: apiKey ?? "sk-",
		baseURL,
		defaultHeaders,
		defaultQuery,
	});
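
	// Legacy text-completions mode: flatten the chat history into a single prompt
	// string with buildPrompt and stream the raw completion back.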
if (completion === "completions") {
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
preprompt,
model,
});
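
			// Per-request generateSettings override the model's default parameters.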
			const parameters = { ...model.parameters, ...generateSettings };

			return openAICompletionToTextGenerationStream(
				await openai.completions.create({
					model: model.id ?? model.name,
					prompt,
					stream: true,
					max_tokens: parameters?.max_new_tokens,
					stop: parameters?.stop,
					temperature: parameters?.temperature,
					top_p: parameters?.top_p,
					frequency_penalty: parameters?.repetition_penalty,
				})
			);
		};
	} else if (completion === "chat_completions") {
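		// Chat-completions mode: map the app's { from, content } messages onto
		// OpenAI chat roles and stream the response back.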
		return async ({ messages, preprompt, generateSettings }) => {
			let messagesOpenAI = messages.map((message) => ({
				role: message.from,
				content: message.content,
			}));
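
			// Ensure the conversation starts with a system message, then place the
			// preprompt (possibly empty) into it.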
			if (messagesOpenAI?.[0]?.role !== "system") {
				messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI];
			}

			if (messagesOpenAI?.[0]) {
				messagesOpenAI[0].content = preprompt ?? "";
			}

			const parameters = { ...model.parameters, ...generateSettings };

			return openAIChatToTextGenerationStream(
				await openai.chat.completions.create({
					model: model.id ?? model.name,
					messages: messagesOpenAI,
					stream: true,
					max_tokens: parameters?.max_new_tokens,
					stop: parameters?.stop,
					temperature: parameters?.temperature,
					top_p: parameters?.top_p,
					frequency_penalty: parameters?.repetition_penalty,
				})
			);
		};
	} else {
		throw new Error("Invalid completion type");
	}
}
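
// Usage sketch (illustrative assumption, not part of this module): the caller
// validates an endpoint entry such as
//   { type: "openai", baseURL: "https://api.openai.com/v1" }
// (the `model` object is attached by the caller before parsing) and then awaits
// endpointOai(config) to obtain the streaming Endpoint used by the server.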