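// Cohere chat endpoint: streams completions from the Cohere API and adapts
// them to the TextGenerationStreamOutput interface used by the endpoints layer.
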
import { z } from "zod";
import { COHERE_API_TOKEN } from "$env/static/private";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Cohere, CohereClient } from "cohere-ai";
import { buildPrompt } from "$lib/buildPrompt";

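// Zod schema for configuring a Cohere endpoint. `apiKey` defaults to the
// COHERE_API_TOKEN environment variable; `raw` enables raw prompting mode.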
export const endpointCohereParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("cohere"),
	apiKey: z.string().default(COHERE_API_TOKEN),
	raw: z.boolean().default(false),
});

export async function endpointCohere(
	input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
	const { apiKey, model, raw } = endpointCohereParametersSchema.parse(input);

	let cohere: CohereClient;

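	// Load the SDK lazily so it is only required when a Cohere endpoint is used.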
	try {
		cohere = new (await import("cohere-ai")).CohereClient({
			token: apiKey,
		});
	} catch (e) {
		throw new Error("Failed to import cohere-ai", { cause: e });
	}

	return async ({ messages, preprompt, generateSettings, continueMessage }) => {
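		// A leading "system" message overrides the configured preprompt.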
		let system = preprompt;
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		const parameters = { ...model.parameters, ...generateSettings };

		return (async function* () {
			let stream;
			let tokenId = 0;

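			// Raw mode: render the whole conversation into a single prompt string with
			// the model's chat template and tell Cohere to skip its own templating.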
			if (raw) {
				const prompt = await buildPrompt({
					messages: messages.filter((message) => message.from !== "system"),
					model,
					preprompt: system,
					continueMessage,
				});

				stream = await cohere.chatStream({
					message: prompt,
					rawPrompting: true,
					model: model.id ?? model.name,
					p: parameters?.top_p,
					k: parameters?.top_k,
					maxTokens: parameters?.max_new_tokens,
					temperature: parameters?.temperature,
					stopSequences: parameters?.stop,
					frequencyPenalty: parameters?.frequency_penalty,
				});
			} else {
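				// Chat mode: convert messages to Cohere's chat format. Prior turns become
				// `chatHistory`; the final message is sent as `message`.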
				const formattedMessages = messages
					.filter((message) => message.from !== "system")
					.map((message) => ({
						role: message.from === "user" ? "USER" : "CHATBOT",
						message: message.content,
					})) satisfies Cohere.ChatMessage[];

				stream = await cohere.chatStream({
					model: model.id ?? model.name,
					chatHistory: formattedMessages.slice(0, -1),
					message: formattedMessages[formattedMessages.length - 1].message,
					preamble: system,
					p: parameters?.top_p,
					k: parameters?.top_k,
					maxTokens: parameters?.max_new_tokens,
					temperature: parameters?.temperature,
					stopSequences: parameters?.stop,
					frequencyPenalty: parameters?.frequency_penalty,
				});
			}

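			// Adapt Cohere stream events to the TGI-style TextGenerationStreamOutput
			// interface consumed by the caller.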
			for await (const output of stream) {
				if (output.eventType === "text-generation") {
					yield {
						token: {
							id: tokenId++,
							text: output.text,
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else if (output.eventType === "stream-end") {
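					// Propagate Cohere error finish reasons, then emit a final special
					// token carrying the complete generated text.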
					if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
						throw new Error(output.finishReason);
					}
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: output.response.text,
						details: null,
					};
				}
			}
		})();
	};
}