Spaces:
Running
Running
Galén
Add support for passing an API key or any other custom token in the authorization header (#579)
a1afcb6
unverified
import { z } from "zod"; | |
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream"; | |
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream"; | |
import { buildPrompt } from "$lib/buildPrompt"; | |
import { OPENAI_API_KEY } from "$env/static/private"; | |
import type { Endpoint } from "../endpoints"; | |
/**
 * Configuration schema for an OpenAI-compatible endpoint.
 *
 * - `weight`: positive integer, defaults to 1.
 * - `model`: accepted without validation (`z.any()`).
 * - `type`: must be the literal `"openai"`.
 * - `baseURL`: API base URL, defaults to `https://api.openai.com/v1`.
 * - `apiKey`: token sent to the API; defaults to `OPENAI_API_KEY`, or to the
 *   placeholder `"sk-"` when that variable is empty.
 * - `completion`: `"completions"` (prompt-based API) or `"chat_completions"`
 *   (message-based API); defaults to `"chat_completions"`.
 */
export const endpointOAIParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("openai"),
	baseURL: z.string().url().default("https://api.openai.com/v1"),
	// `$env/static/private` exposes variables as strings, so a declared-but-unset
	// key arrives as "" rather than undefined. `||` (not `??`) is required for the
	// "sk-" placeholder to ever take effect.
	apiKey: z.string().default(OPENAI_API_KEY || "sk-"),
	completion: z
		.union([z.literal("completions"), z.literal("chat_completions")])
		.default("chat_completions"),
});
/**
 * Builds an {@link Endpoint} backed by an OpenAI-compatible API.
 *
 * `input` is validated (and defaulted) with `endpointOAIParametersSchema`. The
 * `openai` package is loaded with a dynamic `import()`; a failed import is
 * rethrown with the original error attached as `cause`.
 *
 * @param input - Raw endpoint configuration.
 * @returns A function taking `{ conversation }` that returns the stream produced by
 *   `openAICompletionToTextGenerationStream` (when `completion` is `"completions"`,
 *   using a prompt built by `buildPrompt`) or by `openAIChatToTextGenerationStream`
 *   (when `completion` is `"chat_completions"`).
 * @throws ZodError (from `parse`) when `input` does not match the schema.
 * @throws Error `"Failed to import OpenAI"` when the `openai` package cannot be loaded.
 */
export async function endpointOai(
	input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
	const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input);

	let OpenAI;
	try {
		OpenAI = (await import("openai")).OpenAI;
	} catch (e) {
		throw new Error("Failed to import OpenAI", { cause: e });
	}

	const openai = new OpenAI({
		// An explicitly configured empty key bypasses the schema default, so fall
		// back to the same "sk-" placeholder here.
		apiKey: apiKey || "sk-",
		baseURL,
	});

	// Request fields shared by both APIs, read from the model config on every call.
	// NOTE(review): `repetition_penalty` is forwarded unchanged as OpenAI's
	// `frequency_penalty`; the two settings may not use the same scale — confirm.
	const sharedRequestParams = () => ({
		model: model.id ?? model.name,
		max_tokens: model.parameters?.max_new_tokens,
		stop: model.parameters?.stop,
		temperature: model.parameters?.temperature,
		top_p: model.parameters?.top_p,
		frequency_penalty: model.parameters?.repetition_penalty,
	});

	switch (completion) {
		case "completions":
			return async ({ conversation }) => {
				// Optional access so an empty history does not throw on the lookup.
				const lastMessage = conversation.messages[conversation.messages.length - 1];
				const prompt = await buildPrompt({
					messages: conversation.messages,
					webSearch: lastMessage?.webSearch,
					preprompt: conversation.preprompt,
					model,
				});
				return openAICompletionToTextGenerationStream(
					await openai.completions.create({
						...sharedRequestParams(),
						prompt,
						stream: true,
					})
				);
			};
		case "chat_completions":
			return async ({ conversation }) => {
				// NOTE(review): unlike the "completions" branch, the last message's
				// `webSearch` context is not forwarded to the model here.
				const messages = conversation.messages.map((message) => ({
					role: message.from,
					content: message.content,
				}));
				return openAIChatToTextGenerationStream(
					await openai.chat.completions.create({
						...sharedRequestParams(),
						// A non-empty preprompt becomes a leading system message.
						messages: conversation.preprompt
							? [{ role: "system", content: conversation.preprompt }, ...messages]
							: messages,
						stream: true,
					})
				);
			};
		default:
			// Unreachable for schema-validated input; kept as a defensive guard.
			throw new Error("Invalid completion type");
	}
}