import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { createDataStream } from 'ai';
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
import SwitchableStream from '~/lib/.server/llm/switchable-stream';
import type { IProviderSetting } from '~/types/model';
import { createScopedLogger } from '~/utils/logger';

/**
 * Remix action for the chat API route; delegates to {@link chatAction}.
 */
export async function action(args: ActionFunctionArgs) {
  return chatAction(args);
}

const logger = createScopedLogger('api.chat');

/**
 * Decodes a percent-encoded cookie name or value.
 *
 * `decodeURIComponent` throws a `URIError` on malformed escapes (e.g. a stray
 * `%`); in that case the raw text is returned so a single bad cookie cannot
 * fail the whole request.
 */
function safeDecodeURIComponent(value: string): string {
  try {
    return decodeURIComponent(value);
  } catch {
    return value;
  }
}

/**
 * Parses a `Cookie` request header into a name → value map.
 *
 * Segments without an `=` or with an empty name are skipped. Only the first
 * `=` separates name from value, so values may themselves contain `=`.
 *
 * @param cookieHeader - Raw header value; an empty string yields `{}`.
 * @returns Decoded cookie names mapped to decoded values.
 */
function parseCookies(cookieHeader: string): Record<string, string> {
  const cookies: Record<string, string> = {};

  for (const segment of cookieHeader.split(';')) {
    const separatorIndex = segment.indexOf('=');

    if (separatorIndex === -1) {
      continue;
    }

    const name = segment.slice(0, separatorIndex).trim();

    if (!name) {
      continue;
    }

    const value = segment.slice(separatorIndex + 1).trim();
    cookies[safeDecodeURIComponent(name)] = safeDecodeURIComponent(value);
  }

  return cookies;
}

/**
 * Reads a cookie whose value is JSON.
 *
 * Returns `fallback` when the cookie is missing/empty or does not contain
 * valid JSON. The parsed value is NOT validated against `T`.
 *
 * @param cookies - Map produced by {@link parseCookies}.
 * @param name - Cookie name to read.
 * @param fallback - Value returned when the cookie is absent or malformed.
 */
function parseJsonCookie<T>(cookies: Record<string, string>, name: string, fallback: T): T {
  const raw = cookies[name];

  if (!raw) {
    return fallback;
  }

  try {
    return JSON.parse(raw) as T;
  } catch (error: unknown) {
    logger.debug(`Ignoring malformed "${name}" cookie`, error instanceof Error ? error.message : String(error));
    return fallback;
  }
}

/**
 * Streams an LLM chat completion back to the client.
 *
 * When a segment stops with `finishReason === 'length'`, the partial answer
 * and CONTINUE_PROMPT are appended to `messages` and a new segment is streamed
 * into the same response, up to MAX_RESPONSE_SEGMENTS. For any other finish
 * reason, a `usage` message annotation with token totals accumulated across
 * all segments is written and the stream is closed.
 *
 * Errors raised while starting the stream are logged and rethrown as a
 * Response: 401 when the error message mentions "API key", otherwise 500.
 */
async function chatAction({ context, request }: ActionFunctionArgs) {
  const { messages, files, promptId, contextOptimization, isPromptCachingEnabled } = await request.json<{
    messages: Messages;
    files: any;
    promptId?: string;
    contextOptimization: boolean;
    isPromptCachingEnabled: boolean;
  }>();

  // API keys and provider settings arrive as JSON-encoded cookies; parse the header once.
  const cookies = parseCookies(request.headers.get('Cookie') ?? '');
  const apiKeys = parseJsonCookie<Record<string, string>>(cookies, 'apiKeys', {});
  const providerSettings = parseJsonCookie<Record<string, IProviderSetting>>(cookies, 'providers', {});

  const stream = new SwitchableStream();

  // Token usage summed over every segment (initial + continuations) of this response.
  const cumulativeUsage = {
    completionTokens: 0,
    promptTokens: 0,
    totalTokens: 0,
  };

  try {
    const options: StreamingOptions = {
      toolChoice: 'none',
      // eslint-disable-next-line @typescript-eslint/naming-convention
      onFinish: async ({ text: content, finishReason, usage, experimental_providerMetadata }) => {
        logger.debug('usage', JSON.stringify(usage));

        const cacheUsage = experimental_providerMetadata?.anthropic;
        logger.debug('cacheUsage', JSON.stringify(cacheUsage));

        if (usage) {
          cumulativeUsage.completionTokens += Math.round(usage.completionTokens || 0);

          /*
           * Cache writes and cache reads are folded into promptTokens with weights of 1.25 and 0.1.
           * These appear to mirror Anthropic's prompt-caching price multipliers, which would make this
           * a cost-weighted figure rather than a raw token count (NOTE(review): confirm intent).
           */
          cumulativeUsage.promptTokens += Math.round(
            (usage.promptTokens || 0) +
              ((cacheUsage?.cacheCreationInputTokens as number) || 0) * 1.25 +
              ((cacheUsage?.cacheReadInputTokens as number) || 0) * 0.1,
          );
          cumulativeUsage.totalTokens = cumulativeUsage.completionTokens + cumulativeUsage.promptTokens;
        }

        if (finishReason !== 'length') {
          // The model did not hit the output limit: emit the usage annotation and end the response.
          const encoder = new TextEncoder();
          const usageStream = createDataStream({
            async execute(dataStream) {
              dataStream.writeMessageAnnotation({
                type: 'usage',
                value: {
                  completionTokens: cumulativeUsage.completionTokens,
                  promptTokens: cumulativeUsage.promptTokens,
                  totalTokens: cumulativeUsage.totalTokens,
                },
              });
            },
            onError: (error: unknown) => `Custom error: ${error instanceof Error ? error.message : String(error)}`,
          }).pipeThrough(
            new TransformStream({
              transform: (chunk, controller) => {
                // Convert the string stream to a byte stream
                const str = typeof chunk === 'string' ? chunk : JSON.stringify(chunk);
                controller.enqueue(encoder.encode(str));
              },
            }),
          );

          await stream.switchSource(usageStream);

          // Yield one macrotask before closing; presumably lets the usage chunk flush first (NOTE(review): confirm).
          await new Promise((resolve) => setTimeout(resolve, 0));
          stream.close();

          return;
        }

        if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
          /*
           * NOTE(review): onFinish presumably runs after the Response has been returned,
           * so this error is not handled by the catch block below.
           */
          throw new Error('Cannot continue message: Maximum segments reached');
        }

        const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;

        logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);

        // Feed the partial answer back and ask the model to continue from where it stopped.
        messages.push({ role: 'assistant', content });
        messages.push({ role: 'user', content: CONTINUE_PROMPT });

        const result = await startSegment();
        stream.switchSource(result.toDataStream());
      },
    };

    // Starts one generation segment from the current `messages`; used for the first segment and every continuation.
    const startSegment = () =>
      streamText({
        messages,
        env: context.cloudflare.env,
        options,
        apiKeys,
        files,
        providerSettings,
        promptId,
        contextOptimization,
        isPromptCachingEnabled,
      });

    const result = await startSegment();
    stream.switchSource(result.toDataStream());

    return new Response(stream.readable, {
      status: 200,
      headers: {
        // Must be the real HTTP header name; a `contentType` key would be sent as a header named "contenttype".
        'Content-Type': 'text/plain; charset=utf-8',
      },
    });
  } catch (error: unknown) {
    logger.error(error);

    const message = error instanceof Error ? error.message : '';

    if (message.includes('API key')) {
      throw new Response('Invalid or missing API key', {
        status: 401,
        statusText: 'Unauthorized',
      });
    }

    throw new Response(null, {
      status: 500,
      statusText: 'Internal Server Error',
    });
  }
}