// Source snapshot: commit 2e6d1bb (unverified) "Add Sagemaker support (#401)" by nsarrazin (HF staff) — 8.22 kB
import { MESSAGES_BEFORE_LOGIN, RATE_LIMIT } from "$env/static/private";
import { buildPrompt } from "$lib/buildPrompt";
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
import { abortedGenerations } from "$lib/server/abortedGenerations";
import { authCondition, requiresUser } from "$lib/server/auth";
import { collections } from "$lib/server/database";
import { modelEndpoint } from "$lib/server/modelEndpoint";
import { models } from "$lib/server/models";
import { ERROR_MESSAGES } from "$lib/stores/errors.js";
import type { Message } from "$lib/types/Message";
import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays";
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
import { trimPrefix } from "$lib/utils/trimPrefix";
import { trimSuffix } from "$lib/utils/trimSuffix";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { error } from "@sveltejs/kit";
import { ObjectId } from "mongodb";
import { z } from "zod";
import { AwsClient } from "aws4fetch";
/**
 * Post-processes raw model output before it is stored as the assistant message.
 *
 * Steps, in order: drop a leading copy of `prompt` if the model echoed it, pass the
 * text through `trimPrefix(…, "<|startoftext|>")` and `trimSuffix(…, PUBLIC_SEP_TOKEN)`,
 * then remove each stop sequence found at the end of the text.
 *
 * @param generatedText Text returned by the model endpoint.
 * @param prompt Prompt that was sent to the model.
 * @param stopSequences Stop sequences to strip from the end; empty strings are ignored.
 * @returns The cleaned text, with trailing whitespace removed.
 */
function cleanGeneratedText(
	generatedText: string,
	prompt: string,
	stopSequences: readonly string[]
): string {
	let text = generatedText;

	// We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
	if (text.startsWith(prompt)) {
		text = text.slice(prompt.length);
	}

	text = trimSuffix(trimPrefix(text, "<|startoftext|>"), PUBLIC_SEP_TOKEN).trimEnd();

	for (const stop of stopSequences) {
		// Every string ends with "", and `slice(0, -0)` is `slice(0, 0)`: an empty
		// stop sequence would otherwise erase the whole answer.
		if (stop && text.endsWith(stop)) {
			text = text.slice(0, -stop.length).trimEnd();
		}
	}

	return text;
}

/**
 * Handles a new user prompt for the conversation identified by `params.id`.
 *
 * Checks access and message limits, builds the model prompt, forwards it to a model
 * endpoint and streams the endpoint's response back to the caller. In the background
 * (not awaited) the streamed output is parsed and the resulting assistant message is
 * written to the conversation.
 *
 * Errors raised through `error(...)`:
 * - 401 when neither a user id nor a session id is available,
 * - 404 when the id is malformed or no conversation matches it for this caller,
 * - 429 when the pre-login message limit or the rate limit is exceeded,
 * - 410 when the conversation's model is no longer in `models`.
 */
export async function POST({ request, fetch, locals, params }) {
	const id = z.string().parse(params.id);

	// A malformed id cannot identify any conversation. Answer 404 rather than letting
	// the ObjectId constructor throw, which would surface as an internal error.
	if (!ObjectId.isValid(id)) {
		throw error(404, "Conversation not found");
	}

	const convId = new ObjectId(id);
	// Captured before generation starts; used later to tell whether an abort request
	// was recorded after this prompt.
	const date = new Date();

	const userId = locals.user?._id ?? locals.sessionId;

	if (!userId) {
		throw error(401, "Unauthorized");
	}

	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	if (
		!locals.user?._id &&
		requiresUser &&
		conv.messages.length > (MESSAGES_BEFORE_LOGIN ? parseInt(MESSAGES_BEFORE_LOGIN, 10) : 0)
	) {
		throw error(429, "Exceeded number of messages before login");
	}

	const nEvents = await collections.messageEvents.countDocuments({ userId });

	if (RATE_LIMIT !== "" && nEvents > parseInt(RATE_LIMIT, 10)) {
		throw error(429, ERROR_MESSAGES.rateLimited);
	}

	const model = models.find((m) => m.id === conv.model);

	if (!model) {
		throw error(410, "Model not available anymore");
	}

	const json = await request.json();
	const {
		inputs: newPrompt,
		options: { id: messageId, is_retry, web_search_id, response_id: responseId },
	} = z
		.object({
			inputs: z.string().trim().min(1),
			options: z.object({
				id: z.optional(z.string().uuid()),
				response_id: z.optional(z.string().uuid()),
				is_retry: z.optional(z.boolean()),
				web_search_id: z.ostring(),
			}),
		})
		.parse(json);

	const messages = (() => {
		if (is_retry && messageId) {
			// Retry: keep only the history before the retried message and re-add the
			// prompt under the same id. An unknown id keeps the whole history.
			let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
			if (retryMessageIdx === -1) {
				retryMessageIdx = conv.messages.length;
			}
			return [
				...conv.messages.slice(0, retryMessageIdx),
				{ content: newPrompt, from: "user", id: messageId as Message["id"], updatedAt: new Date() },
			];
		}
		return [
			...conv.messages,
			{
				content: newPrompt,
				from: "user",
				id: (messageId as Message["id"]) || crypto.randomUUID(),
				createdAt: new Date(),
				updatedAt: new Date(),
			},
		];
	})() satisfies Message[];

	const prompt = await buildPrompt(messages, model, web_search_id);
	const randomEndpoint = modelEndpoint(model);

	const abortController = new AbortController();

	// Both endpoint kinds receive the same payload: the client's JSON with `inputs`
	// replaced by the fully built prompt.
	const body = JSON.stringify({
		...json,
		inputs: prompt,
	});

	let resp: Response;
	if (randomEndpoint.host === "sagemaker") {
		// Sagemaker endpoints are called through AwsClient with the endpoint's credentials.
		const aws = new AwsClient({
			accessKeyId: randomEndpoint.accessKey,
			secretAccessKey: randomEndpoint.secretKey,
			sessionToken: randomEndpoint.sessionToken,
			service: "sagemaker",
		});

		resp = await aws.fetch(randomEndpoint.url, {
			method: "POST",
			body,
			signal: abortController.signal,
			headers: {
				"Content-Type": "application/json",
			},
		});
	} else {
		resp = await fetch(randomEndpoint.url, {
			headers: {
				"Content-Type": request.headers.get("Content-Type") ?? "application/json",
				Authorization: randomEndpoint.authorization,
			},
			method: "POST",
			body,
			signal: abortController.signal,
		});
	}

	if (!resp.body) {
		throw new Error("Response body is empty");
	}

	// One branch of the stream goes to the client, the other is consumed here to
	// persist the final answer.
	const [stream1, stream2] = resp.body.tee();

	async function saveMessage() {
		const rawText = await parseGeneratedText(stream2, convId, date, abortController);
		const generated_text = cleanGeneratedText(rawText, prompt, [
			...(model?.parameters?.stop ?? []),
			"<|endoftext|>",
		]);

		messages.push({
			from: "assistant",
			content: generated_text,
			webSearchId: web_search_id,
			id: (responseId as Message["id"]) || crypto.randomUUID(),
			createdAt: new Date(),
			updatedAt: new Date(),
		});

		// One event per saved answer; these documents are what the rate-limit check counts.
		await collections.messageEvents.insertOne({
			userId: userId,
			createdAt: new Date(),
		});

		await collections.conversations.updateOne(
			{
				_id: convId,
			},
			{
				$set: {
					messages,
					updatedAt: new Date(),
				},
			}
		);
	}

	// Fire and forget: failures are logged and do not affect the streamed response.
	saveMessage().catch(console.error);

	// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
	return new Response(stream1, {
		headers: Object.fromEntries(resp.headers.entries()),
		status: resp.status,
		statusText: resp.statusText,
	});
}
/**
 * Deletes the conversation identified by `params.id`.
 *
 * The conversation is first looked up with the caller's `authCondition`, so only a
 * conversation matching that condition can be deleted.
 *
 * Raises a 404 through `error(...)` when the id is malformed or no matching
 * conversation exists. Returns an empty response on success.
 */
export async function DELETE({ locals, params }) {
	// A malformed id cannot identify any conversation. Answer 404 rather than letting
	// the ObjectId constructor throw, which would surface as an internal error.
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}

	const convId = new ObjectId(params.id);

	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	await collections.conversations.deleteOne({ _id: conv._id });

	return new Response();
}
/**
 * Consumes a streamed generation response and returns the generated text.
 *
 * The payload is expected to contain lines starting with `data:` followed by JSON.
 * After every chunk, `abortedGenerations` is checked for this conversation: if it
 * holds a date later than `promptedAt`, the request is aborted and the text of the
 * tokens received so far is returned. Otherwise the stream is read to its end and
 * `generated_text` is taken from the last `data:` line.
 *
 * @param stream Response body to read.
 * @param conversationId Conversation whose abort marker is checked.
 * @param promptedAt Time the prompt was received; older abort markers are ignored.
 * @param abortController Controller of the upstream request, aborted on cancellation.
 * @returns The final generated text, or the partial text on cancellation.
 * @throws Error when no `data:` line is present, when the last line carries an
 *   `error` field, or when it has no string `generated_text`. A last line that is not
 *   valid JSON makes `JSON.parse` throw.
 */
async function parseGeneratedText(
	stream: ReadableStream,
	conversationId: ObjectId,
	promptedAt: Date,
	abortController: AbortController
): Promise<string> {
	const inputs: Uint8Array[] = [];
	for await (const input of streamToAsyncIterable(stream)) {
		inputs.push(input);

		const date = abortedGenerations.get(conversationId.toString());

		if (date && date > promptedAt) {
			abortController.abort("Cancelled by user");

			// Rebuild the partial answer from the token of every line received so far.
			const completeInput = concatUint8Arrays(inputs);
			const lines = new TextDecoder()
				.decode(completeInput)
				.split("\n")
				.filter((line) => line.startsWith("data:"));

			const tokens = lines.map((line) => {
				try {
					const json: TextGenerationStreamOutput = JSON.parse(line.slice("data:".length));
					return json.token.text;
				} catch {
					// The last line may be cut off mid-JSON; it contributes nothing.
					return "";
				}
			});
			return tokens.join("");
		}
	}

	// Merge inputs into a single Uint8Array
	const completeInput = concatUint8Arrays(inputs);

	// Get last line starting with "data:" and parse it as JSON to get the generated text
	const message = new TextDecoder().decode(completeInput);

	let lastIndex = message.lastIndexOf("\ndata:");
	if (lastIndex === -1) {
		// No line break before the marker: the payload may consist of a single line.
		lastIndex = message.indexOf("data:");
	}

	if (lastIndex === -1) {
		// Stop here: slicing with -1 would keep only the last character and fail
		// later with an unrelated JSON syntax error.
		console.error("Could not parse last message", message);
		throw new Error("Could not parse last message");
	}

	let lastMessage = message.slice(lastIndex).trim().slice("data:".length);
	if (lastMessage.includes("\n")) {
		lastMessage = lastMessage.slice(0, lastMessage.indexOf("\n"));
	}

	const lastMessageJSON = JSON.parse(lastMessage) as {
		error?: unknown;
		generated_text?: unknown;
	} | null;

	const reportedError = lastMessageJSON?.error;
	if (reportedError) {
		// Keep structured errors readable instead of "[object Object]".
		throw new Error(
			typeof reportedError === "string" ? reportedError : JSON.stringify(reportedError)
		);
	}

	const res = lastMessageJSON?.generated_text;

	if (typeof res !== "string") {
		throw new Error("Could not parse generated text");
	}

	return res;
}
/**
 * Renames the conversation identified by `params.id`.
 *
 * The request body must be JSON with a `title` string that is 1 to 100 characters
 * long after trimming; the schema's `parse` throws otherwise. The conversation is
 * looked up with the caller's `authCondition` before it is updated.
 *
 * Raises a 404 through `error(...)` when the id is malformed or no matching
 * conversation exists. Returns an empty response on success.
 */
export async function PATCH({ request, locals, params }) {
	const { title } = z
		.object({ title: z.string().trim().min(1).max(100) })
		.parse(await request.json());

	// A malformed id cannot identify any conversation. Answer 404 rather than letting
	// the ObjectId constructor throw, which would surface as an internal error.
	if (!ObjectId.isValid(params.id)) {
		throw error(404, "Conversation not found");
	}

	const convId = new ObjectId(params.id);

	const conv = await collections.conversations.findOne({
		_id: convId,
		...authCondition(locals),
	});

	if (!conv) {
		throw error(404, "Conversation not found");
	}

	await collections.conversations.updateOne(
		{
			_id: convId,
		},
		{
			$set: {
				title,
			},
		}
	);

	return new Response();
}