nsarrazin HF staff commited on
Commit
2e6d1bb
·
unverified ·
1 Parent(s): 66adc5d

Add Sagemaker support (#401)

Browse files

* work on sagemaker support

* fix sagemaker integration

* remove unnecessary deps

* fix default endpoint

* remove unneeded deps, fixed types

* Use conditional validation for endpoints

This was needed because the discriminated union couldn't handle the legacy case where `host` is undefined.

* add note in readme about aws sagemaker

README.md CHANGED
@@ -198,6 +198,24 @@ You can then add the generated information and the `authorization` parameter to
198
 
199
  ```
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  #### Client Certificate Authentication (mTLS)
202
 
203
  Custom endpoints may require client certificate authentication, depending on how you configure them. To enable mTLS between Chat UI and your custom endpoint, you will need to set the `USE_CLIENT_CERTIFICATE` to `true`, and add the `CERT_PATH` and `KEY_PATH` parameters to your `.env.local`. These parameters should point to the location of the certificate and key files on your local machine. The certificate and key files should be in PEM format. The key file can be encrypted with a passphrase, in which case you will also need to add the `CLIENT_KEY_PASSWORD` parameter to your `.env.local`.
 
198
 
199
  ```
200
 
201
+ ### Amazon SageMaker
202
+
203
+ You can also specify your Amazon SageMaker instance as an endpoint for chat-ui. The config goes like this:
204
+
205
+ ```
206
+ "endpoints": [
207
+ {
208
+ "host" : "sagemaker",
209
+ "url": "", // your aws sagemaker url here
210
+ "accessKey": "",
211
+ "secretKey" : "",
212
+ "sessionToken": "", // optional
213
+ "weight": 1
214
+ }
215
+ ```
216
+
217
+ You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.
218
+
219
  #### Client Certificate Authentication (mTLS)
220
 
221
  Custom endpoints may require client certificate authentication, depending on how you configure them. To enable mTLS between Chat UI and your custom endpoint, you will need to set the `USE_CLIENT_CERTIFICATE` to `true`, and add the `CERT_PATH` and `KEY_PATH` parameters to your `.env.local`. These parameters should point to the location of the certificate and key files on your local machine. The certificate and key files should be in PEM format. The key file can be encrypted with a passphrase, in which case you will also need to add the `CLIENT_KEY_PASSWORD` parameter to your `.env.local`.
package-lock.json CHANGED
@@ -11,6 +11,7 @@
11
  "@huggingface/hub": "^0.5.1",
12
  "@huggingface/inference": "^2.2.0",
13
  "autoprefixer": "^10.4.14",
 
14
  "date-fns": "^2.29.3",
15
  "dotenv": "^16.0.3",
16
  "highlight.js": "^11.7.0",
@@ -1465,6 +1466,11 @@
1465
  "postcss": "^8.1.0"
1466
  }
1467
  },
 
 
 
 
 
1468
  "node_modules/balanced-match": {
1469
  "version": "1.0.2",
1470
  "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
 
11
  "@huggingface/hub": "^0.5.1",
12
  "@huggingface/inference": "^2.2.0",
13
  "autoprefixer": "^10.4.14",
14
+ "aws4fetch": "^1.0.17",
15
  "date-fns": "^2.29.3",
16
  "dotenv": "^16.0.3",
17
  "highlight.js": "^11.7.0",
 
1466
  "postcss": "^8.1.0"
1467
  }
1468
  },
1469
+ "node_modules/aws4fetch": {
1470
+ "version": "1.0.17",
1471
+ "resolved": "https://registry.npmjs.org/aws4fetch/-/aws4fetch-1.0.17.tgz",
1472
+ "integrity": "sha512-4IbOvsxqxeOSxI4oA+8xEO8SzBMVlzbSTgGy/EF83rHnQ/aKtP6Sc6YV/k0oiW0mqrcxuThlbDosnvetGOuO+g=="
1473
+ },
1474
  "node_modules/balanced-match": {
1475
  "version": "1.0.2",
1476
  "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
package.json CHANGED
@@ -43,6 +43,7 @@
43
  "@huggingface/hub": "^0.5.1",
44
  "@huggingface/inference": "^2.2.0",
45
  "autoprefixer": "^10.4.14",
 
46
  "date-fns": "^2.29.3",
47
  "dotenv": "^16.0.3",
48
  "highlight.js": "^11.7.0",
 
43
  "@huggingface/hub": "^0.5.1",
44
  "@huggingface/inference": "^2.2.0",
45
  "autoprefixer": "^10.4.14",
46
+ "aws4fetch": "^1.0.17",
47
  "date-fns": "^2.29.3",
48
  "dotenv": "^16.0.3",
49
  "highlight.js": "^11.7.0",
src/lib/server/generateFromDefaultEndpoint.ts CHANGED
@@ -1,9 +1,9 @@
1
  import { defaultModel } from "$lib/server/models";
2
  import { modelEndpoint } from "./modelEndpoint";
3
- import { textGeneration } from "@huggingface/inference";
4
  import { trimSuffix } from "$lib/utils/trimSuffix";
5
  import { trimPrefix } from "$lib/utils/trimPrefix";
6
  import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
 
7
 
8
  interface Parameters {
9
  temperature: number;
@@ -21,24 +21,76 @@ export async function generateFromDefaultEndpoint(
21
  return_full_text: false,
22
  };
23
 
24
- const endpoint = modelEndpoint(defaultModel);
25
- let { generated_text } = await textGeneration(
26
- {
27
- model: endpoint.url,
 
 
 
 
 
28
  inputs: prompt,
29
- parameters: newParameters,
30
- },
31
- {
32
- fetch: (url, options) =>
33
- fetch(url, {
34
- ...options,
35
- headers: { ...options?.headers, Authorization: endpoint.authorization },
36
- }),
37
- }
38
- );
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- generated_text = trimSuffix(
41
- trimPrefix(generated_text, "<|startoftext|>"),
42
  PUBLIC_SEP_TOKEN
43
  ).trimEnd();
44
 
 
1
  import { defaultModel } from "$lib/server/models";
2
  import { modelEndpoint } from "./modelEndpoint";
 
3
  import { trimSuffix } from "$lib/utils/trimSuffix";
4
  import { trimPrefix } from "$lib/utils/trimPrefix";
5
  import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
6
+ import { AwsClient } from "aws4fetch";
7
 
8
  interface Parameters {
9
  temperature: number;
 
21
  return_full_text: false,
22
  };
23
 
24
+ const randomEndpoint = modelEndpoint(defaultModel);
25
+
26
+ const abortController = new AbortController();
27
+
28
+ let resp: Response;
29
+
30
+ if (randomEndpoint.host === "sagemaker") {
31
+ const requestParams = JSON.stringify({
32
+ ...newParameters,
33
  inputs: prompt,
34
+ });
35
+
36
+ const aws = new AwsClient({
37
+ accessKeyId: randomEndpoint.accessKey,
38
+ secretAccessKey: randomEndpoint.secretKey,
39
+ sessionToken: randomEndpoint.sessionToken,
40
+ service: "sagemaker",
41
+ });
42
+
43
+ resp = await aws.fetch(randomEndpoint.url, {
44
+ method: "POST",
45
+ body: requestParams,
46
+ signal: abortController.signal,
47
+ headers: {
48
+ "Content-Type": "application/json",
49
+ },
50
+ });
51
+ } else {
52
+ resp = await fetch(randomEndpoint.url, {
53
+ headers: {
54
+ "Content-Type": "application/json",
55
+ Authorization: randomEndpoint.authorization,
56
+ },
57
+ method: "POST",
58
+ body: JSON.stringify({
59
+ ...newParameters,
60
+ inputs: prompt,
61
+ }),
62
+ signal: abortController.signal,
63
+ });
64
+ }
65
+
66
+ if (!resp.ok) {
67
+ throw new Error(await resp.text());
68
+ }
69
+
70
+ if (!resp.body) {
71
+ throw new Error("Response body is empty");
72
+ }
73
+
74
+ const decoder = new TextDecoder();
75
+ const reader = resp.body.getReader();
76
+
77
+ let isDone = false;
78
+ let result = "";
79
+
80
+ while (!isDone) {
81
+ const { done, value } = await reader.read();
82
+
83
+ isDone = done;
84
+ result += decoder.decode(value, { stream: true }); // Convert current chunk to text
85
+ }
86
+
87
+ // Close the reader when done
88
+ reader.releaseLock();
89
+
90
+ const results = await JSON.parse(result);
91
 
92
+ let generated_text = trimSuffix(
93
+ trimPrefix(trimPrefix(results[0].generated_text, "<|startoftext|>"), prompt),
94
  PUBLIC_SEP_TOKEN
95
  ).trimEnd();
96
 
src/lib/server/modelEndpoint.ts CHANGED
@@ -9,7 +9,7 @@ import {
9
  REJECT_UNAUTHORIZED,
10
  } from "$env/static/private";
11
  import { sum } from "$lib/utils/sum";
12
- import type { BackendModel } from "./models";
13
 
14
  import { loadClientCertificates } from "$lib/utils/loadClientCerts";
15
 
@@ -26,13 +26,10 @@ if (USE_CLIENT_CERTIFICATE === "true") {
26
  /**
27
  * Find a random load-balanced endpoint
28
  */
29
- export function modelEndpoint(model: BackendModel): {
30
- url: string;
31
- authorization: string;
32
- weight: number;
33
- } {
34
  if (!model.endpoints) {
35
  return {
 
36
  url: `${HF_API_ROOT}/${model.name}`,
37
  authorization: `Bearer ${HF_ACCESS_TOKEN}`,
38
  weight: 1,
 
9
  REJECT_UNAUTHORIZED,
10
  } from "$env/static/private";
11
  import { sum } from "$lib/utils/sum";
12
+ import type { BackendModel, Endpoint } from "./models";
13
 
14
  import { loadClientCertificates } from "$lib/utils/loadClientCerts";
15
 
 
26
  /**
27
  * Find a random load-balanced endpoint
28
  */
29
+ export function modelEndpoint(model: BackendModel): Endpoint {
 
 
 
 
30
  if (!model.endpoints) {
31
  return {
32
+ host: "tgi",
33
  url: `${HF_API_ROOT}/${model.name}`,
34
  authorization: `Bearer ${HF_ACCESS_TOKEN}`,
35
  weight: 1,
src/lib/server/models.ts CHANGED
@@ -1,6 +1,38 @@
1
  import { HF_ACCESS_TOKEN, MODELS, OLD_MODELS } from "$env/static/private";
2
  import { z } from "zod";
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  const modelsRaw = z
5
  .array(
6
  z.object({
@@ -29,15 +61,7 @@ const modelsRaw = z
29
  })
30
  )
31
  .optional(),
32
- endpoints: z
33
- .array(
34
- z.object({
35
- url: z.string().url(),
36
- authorization: z.string().min(1).default(`Bearer ${HF_ACCESS_TOKEN}`),
37
- weight: z.number().int().positive().default(1),
38
- })
39
- )
40
- .optional(),
41
  parameters: z
42
  .object({
43
  temperature: z.number().min(0).max(1),
@@ -77,6 +101,7 @@ export const oldModels = OLD_MODELS
77
  : [];
78
 
79
  export type BackendModel = (typeof models)[0];
 
80
 
81
  export const defaultModel = models[0];
82
 
 
1
  import { HF_ACCESS_TOKEN, MODELS, OLD_MODELS } from "$env/static/private";
2
  import { z } from "zod";
3
 
4
+ const sagemakerEndpoint = z.object({
5
+ host: z.literal("sagemaker"),
6
+ url: z.string().url(),
7
+ accessKey: z.string().min(1),
8
+ secretKey: z.string().min(1),
9
+ sessionToken: z.string().optional(),
10
+ });
11
+
12
+ const tgiEndpoint = z.object({
13
+ host: z.union([z.literal("tgi"), z.undefined()]),
14
+ url: z.string().url(),
15
+ authorization: z.string().min(1).default(`Bearer ${HF_ACCESS_TOKEN}`),
16
+ });
17
+
18
+ const commonEndpoint = z.object({
19
+ weight: z.number().int().positive().default(1),
20
+ });
21
+
22
+ const endpoint = z.lazy(() =>
23
+ z.union([sagemakerEndpoint.merge(commonEndpoint), tgiEndpoint.merge(commonEndpoint)])
24
+ );
25
+
26
+ const combinedEndpoint = endpoint.transform((data) => {
27
+ if (data.host === "tgi" || data.host === undefined) {
28
+ return tgiEndpoint.merge(commonEndpoint).parse(data);
29
+ } else if (data.host === "sagemaker") {
30
+ return sagemakerEndpoint.merge(commonEndpoint).parse(data);
31
+ } else {
32
+ throw new Error(`Invalid host: ${data.host}`);
33
+ }
34
+ });
35
+
36
  const modelsRaw = z
37
  .array(
38
  z.object({
 
61
  })
62
  )
63
  .optional(),
64
+ endpoints: z.array(combinedEndpoint).optional(),
 
 
 
 
 
 
 
 
65
  parameters: z
66
  .object({
67
  temperature: z.number().min(0).max(1),
 
101
  : [];
102
 
103
  export type BackendModel = (typeof models)[0];
104
+ export type Endpoint = z.infer<typeof endpoint>;
105
 
106
  export const defaultModel = models[0];
107
 
src/routes/conversation/[id]/+server.ts CHANGED
@@ -16,6 +16,7 @@ import type { TextGenerationStreamOutput } from "@huggingface/inference";
16
  import { error } from "@sveltejs/kit";
17
  import { ObjectId } from "mongodb";
18
  import { z } from "zod";
 
19
 
20
  export async function POST({ request, fetch, locals, params }) {
21
  const id = z.string().parse(params.id);
@@ -101,18 +102,42 @@ export async function POST({ request, fetch, locals, params }) {
101
 
102
  const abortController = new AbortController();
103
 
104
- const resp = await fetch(randomEndpoint.url, {
105
- headers: {
106
- "Content-Type": request.headers.get("Content-Type") ?? "application/json",
107
- Authorization: randomEndpoint.authorization,
108
- },
109
- method: "POST",
110
- body: JSON.stringify({
111
  ...json,
112
  inputs: prompt,
113
- }),
114
- signal: abortController.signal,
115
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  if (!resp.body) {
118
  throw new Error("Response body is empty");
 
16
  import { error } from "@sveltejs/kit";
17
  import { ObjectId } from "mongodb";
18
  import { z } from "zod";
19
+ import { AwsClient } from "aws4fetch";
20
 
21
  export async function POST({ request, fetch, locals, params }) {
22
  const id = z.string().parse(params.id);
 
102
 
103
  const abortController = new AbortController();
104
 
105
+ let resp: Response;
106
+ if (randomEndpoint.host === "sagemaker") {
107
+ const requestParams = JSON.stringify({
 
 
 
 
108
  ...json,
109
  inputs: prompt,
110
+ });
111
+
112
+ const aws = new AwsClient({
113
+ accessKeyId: randomEndpoint.accessKey,
114
+ secretAccessKey: randomEndpoint.secretKey,
115
+ sessionToken: randomEndpoint.sessionToken,
116
+ service: "sagemaker",
117
+ });
118
+
119
+ resp = await aws.fetch(randomEndpoint.url, {
120
+ method: "POST",
121
+ body: requestParams,
122
+ signal: abortController.signal,
123
+ headers: {
124
+ "Content-Type": "application/json",
125
+ },
126
+ });
127
+ } else {
128
+ resp = await fetch(randomEndpoint.url, {
129
+ headers: {
130
+ "Content-Type": request.headers.get("Content-Type") ?? "application/json",
131
+ Authorization: randomEndpoint.authorization,
132
+ },
133
+ method: "POST",
134
+ body: JSON.stringify({
135
+ ...json,
136
+ inputs: prompt,
137
+ }),
138
+ signal: abortController.signal,
139
+ });
140
+ }
141
 
142
  if (!resp.body) {
143
  throw new Error("Response body is empty");