Spaces:
Running
Running
File size: 3,584 Bytes
cb000d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_API_TOKEN } from "$env/static/private";
/**
 * Field validators for the Cloudflare endpoint configuration, declared as a
 * plain record so that each option can carry its own description.
 */
const cloudflareParametersShape = {
  /** Positive integer; defaults to 1 when omitted. */
  weight: z.number().int().positive().default(1),
  /** Model description; accepted as-is, without validation. */
  model: z.any(),
  /** Must be the literal string "cloudflare". */
  type: z.literal("cloudflare"),
  /** Account id; defaults to the CLOUDFLARE_ACCOUNT_ID private env value. */
  accountId: z.string().default(CLOUDFLARE_ACCOUNT_ID),
  /** API token; defaults to the CLOUDFLARE_API_TOKEN private env value. */
  apiToken: z.string().default(CLOUDFLARE_API_TOKEN),
};

/**
 * Zod schema describing the parameters accepted by `endpointCloudflare`.
 */
export const endpointCloudflareParametersSchema = z.object(cloudflareParametersShape);
/**
 * Builds a text-generation endpoint backed by Cloudflare's `/ai/run` REST API.
 *
 * The configuration is validated with `endpointCloudflareParametersSchema`. The
 * returned function POSTs the conversation (with `stream: true`) to
 * `/accounts/{accountId}/ai/run/@hf/{model.id}` and converts the `data: ...`
 * lines of the streamed response into `TextGenerationStreamOutput` items.
 *
 * @param input - Raw endpoint configuration, parsed by the schema.
 * @returns An `Endpoint`. Calling it with `messages` (and an optional
 *   `preprompt`) resolves to an async generator that yields one output per
 *   non-empty `response` fragment, then - once `data: [DONE]` is received - a
 *   final special token whose `generated_text` is the concatenation of all
 *   fragments. If the stream ends without `[DONE]`, no final output is emitted.
 * @throws Whatever `parse` raises when `input` does not match the schema.
 * @throws Error (when the endpoint is invoked) if the HTTP response is not OK
 *   or carries no body.
 */
export async function endpointCloudflare(
  input: z.input<typeof endpointCloudflareParametersSchema>
): Promise<Endpoint> {
  const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
  const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;

  return async ({ messages, preprompt }) => {
    let messagesFormatted = messages.map((message) => ({
      role: message.from,
      content: message.content,
    }));

    // Make sure the conversation always starts with a system message.
    if (messagesFormatted?.[0]?.role !== "system") {
      messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
    }

    const res = await fetch(apiURL, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${apiToken}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        messages: messagesFormatted,
        stream: true,
      }),
    });

    if (!res.ok) {
      throw new Error(`Failed to generate text: ${await res.text()}`);
    }
    if (!res.body) {
      // Without a body there is nothing to stream; fail loudly instead of
      // returning a generator that silently yields nothing.
      throw new Error("Failed to generate text: response has no body");
    }

    const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();

    return (async function* () {
      const dataPrefix = "data: ";
      let generatedText = "";
      let tokenId = 0;
      let buffer = ""; // Text received so far that does not yet form a complete line

      try {
        for (;;) {
          const { done, value } = await reader.read();
          buffer += value ?? "";

          // Every element but the last is a complete line. The last one is an
          // unfinished line, unless the stream has ended, in which case it is
          // the final line and must be processed too.
          const lines = buffer.split("\n");
          buffer = done ? "" : lines.pop() ?? "";

          for (const rawLine of lines) {
            const line = rawLine.trim();
            if (!line.startsWith(dataPrefix)) {
              continue;
            }
            const eventData = line.slice(dataPrefix.length);

            if (eventData === "[DONE]") {
              yield {
                token: {
                  id: tokenId++,
                  text: "",
                  logprob: 0,
                  special: true,
                },
                generated_text: generatedText,
                details: null,
              } satisfies TextGenerationStreamOutput;
              // Nothing after the end marker may be emitted.
              return;
            }

            const fragment = parseCloudflareResponseFragment(eventData);
            if (!fragment) {
              continue;
            }

            generatedText += fragment;
            yield {
              token: {
                id: tokenId++,
                text: fragment,
                logprob: 0,
                special: false,
              },
              generated_text: null,
              details: null,
            } satisfies TextGenerationStreamOutput;
          }

          if (done) {
            return;
          }
        }
      } finally {
        // Runs on normal completion, on errors, and when the consumer stops
        // iterating early, so the underlying stream is always released.
        try {
          await reader.cancel();
        } catch {
          // The stream had already failed; let the original error propagate.
        }
      }
    })();
  };
}

/**
 * Extracts the `response` text from the JSON payload of one `data:` line.
 *
 * Returns `undefined` (after logging, for malformed JSON) when the payload is
 * not valid JSON, is not an object, or has no non-empty string `response`.
 */
function parseCloudflareResponseFragment(eventData: string): string | undefined {
  let parsed: unknown;
  try {
    parsed = JSON.parse(eventData);
  } catch (e) {
    console.error("Failed to parse JSON", e);
    console.error("Problematic JSON string:", eventData);
    return undefined;
  }

  // JSON.parse may legitimately return null or a primitive; guard before reading a property.
  if (typeof parsed !== "object" || parsed === null) {
    return undefined;
  }

  const fragment = (parsed as { response?: unknown }).response;
  return typeof fragment === "string" && fragment !== "" ? fragment : undefined;
}
// Expose the endpoint factory as this module's default export as well as by name.
export default endpointCloudflare;
|