jbilcke-hf HF staff commited on
Commit
e52146b
·
1 Parent(s): 166428d

allow support for Inference API for SDXL

Browse files
.env CHANGED
@@ -1,35 +1,49 @@
1
- # ------------- IMAGE API CONFIG --------------
2
  # Supported values:
3
  # - VIDEOCHAIN
4
  # - REPLICATE
5
- RENDERING_ENGINE="REPLICATE"
6
-
7
- VIDEOCHAIN_API_URL="http://localhost:7860"
8
- VIDEOCHAIN_API_TOKEN=
9
-
10
- REPLICATE_API_TOKEN=
11
- REPLICATE_API_MODEL="stabilityai/sdxl"
12
- REPLICATE_API_MODEL_VERSION="da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
13
 
14
- # ------------- LLM API CONFIG ----------------
15
  # Supported values:
16
  # - INFERENCE_ENDPOINT
17
  # - INFERENCE_API
18
- LLM_ENGINE="INFERENCE_ENDPOINT"
 
 
19
 
20
- # Hugging Face token (if you choose to use a custom Inference Endpoint or an Inference API model)
 
21
  HF_API_TOKEN=
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # URL to a custom text-generation Inference Endpoint of your choice
24
  # -> You can leave it empty if you decide to use an Inference API Model instead
25
- HF_INFERENCE_ENDPOINT_URL=
26
 
27
  # You can also use a model from the Inference API (not a custom inference endpoint)
28
  # -> You can leave it empty if you decide to use an Inference Endpoint URL instead
29
- HF_INFERENCE_API_MODEL="codellama/CodeLlama-7b-hf"
30
-
31
- # Not supported yet
32
- OPENAI_TOKEN=
33
 
34
  # ----------- COMMUNITY SHARING (OPTIONAL) -----------
35
  NEXT_PUBLIC_ENABLE_COMMUNITY_SHARING="false"
 
 
1
  # Supported values:
2
  # - VIDEOCHAIN
3
  # - REPLICATE
4
+ # - INFERENCE_ENDPOINT
5
+ # - INFERENCE_API
6
+ RENDERING_ENGINE="INFERENCE_API"
 
 
 
 
 
7
 
 
8
  # Supported values:
9
  # - INFERENCE_ENDPOINT
10
  # - INFERENCE_API
11
+ LLM_ENGINE="INFERENCE_API"
12
+
13
+ # ------------- PROVIDER AUTH ------------
14
 
15
+ # Hugging Face token, if you plan to use the Inference API or Inference Endpoint
16
+ # for the LLM or SDXL generation
17
  HF_API_TOKEN=
18
 
19
+ # Replicate token, if you wish to use them as a provider for SDXL
20
+ REPLICATE_API_TOKEN=
21
+
22
+ # OpenAI is not supported yet
23
+ OPENAI_TOKEN=
24
+
25
+ # VideoChain is a custom API used for SDXL but you don't need it for the base features
26
+ VIDEOCHAIN_API_TOKEN=
27
+
28
+ # ------------- RENDERING API CONFIG --------------
29
+
30
+ RENDERING_VIDEOCHAIN_API_URL="http://localhost:7860"
31
+
32
+ RENDERING_REPLICATE_API_MODEL="stabilityai/sdxl"
33
+ RENDERING_REPLICATE_API_MODEL_VERSION="da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
34
+
35
+ RENDERING_HF_INFERENCE_ENDPOINT_URL="https://XXXXXXXXXX.endpoints.huggingface.cloud"
36
+ RENDERING_HF_INFERENCE_API_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
37
+
38
+ # ------------- LLM API CONFIG ----------------
39
+
40
  # URL to a custom text-generation Inference Endpoint of your choice
41
  # -> You can leave it empty if you decide to use an Inference API Model instead
42
+ LLM_HF_INFERENCE_ENDPOINT_URL=
43
 
44
  # You can also use a model from the Inference API (not a custom inference endpoint)
45
  # -> You can leave it empty if you decide to use an Inference Endpoint URL instead
46
+ LLM_HF_INFERENCE_API_MODEL="codellama/CodeLlama-7b-hf"
 
 
 
47
 
48
  # ----------- COMMUNITY SHARING (OPTIONAL) -----------
49
  NEXT_PUBLIC_ENABLE_COMMUNITY_SHARING="false"
README.md CHANGED
@@ -17,17 +17,28 @@ First, I would like to highlight that everything is open-source (see [here](http
17
  However the project isn't a monolithic Space that can be duplicated and ran immediately:
18
  it requires various components to run for the frontend, backend, LLM, SDXL etc.
19
 
20
- If you try to duplicate the project and open the `.env` you will see it requires some variables:
21
 
 
22
  - `LLM_ENGINE`: can be either "INFERENCE_API" or "INFERENCE_ENDPOINT"
23
- - `HF_API_TOKEN`: necessary if you decide to use an inference api model or a custom inference endpoint
24
- - `HF_INFERENCE_ENDPOINT_URL`: necessary if you decide to use a custom inference endpoint
25
  - `RENDERING_ENGINE`: can only be "VIDEOCHAIN" or "REPLICATE" for now, unless you code your custom solution
26
- - `VIDEOCHAIN_API_URL`: url to the VideoChain API server
 
 
27
  - `VIDEOCHAIN_API_TOKEN`: secret token to access the VideoChain API server
28
  - `REPLICATE_API_TOKEN`: in case you want to use Replicate.com
29
- - `REPLICATE_API_MODEL`: optional, defaults to "stabilityai/sdxl"
30
- - `REPLICATE_API_MODEL_VERSION`: optional, in case you want to change the version
 
 
 
 
 
 
 
 
 
 
31
 
32
  In addition, there are some community sharing variables that you can just ignore.
33
  Those variables are not required to run the AI Comic Factory on your own website or computer
 
17
  However the project isn't a monolithic Space that can be duplicated and ran immediately:
18
  it requires various components to run for the frontend, backend, LLM, SDXL etc.
19
 
20
+ If you try to duplicate the project, open the `.env` you will see it requires some variables.
21
 
22
+ Provider config:
23
  - `LLM_ENGINE`: can be either "INFERENCE_API" or "INFERENCE_ENDPOINT"
 
 
24
  - `RENDERING_ENGINE`: can only be "VIDEOCHAIN" or "REPLICATE" for now, unless you code your custom solution
25
+
26
+ Auth config:
27
+ - `HF_API_TOKEN`: necessary if you decide to use an inference api model or a custom inference endpoint
28
  - `VIDEOCHAIN_API_TOKEN`: secret token to access the VideoChain API server
29
  - `REPLICATE_API_TOKEN`: in case you want to use Replicate.com
30
+
31
+ Rendering config:
32
+ - `RENDERING_HF_INFERENCE_ENDPOINT_URL`: necessary if you decide to use a custom inference endpoint
33
+ - `RENDERING_REPLICATE_API_MODEL_VERSION`: url to the VideoChain API server
34
+ - `RENDERING_HF_INFERENCE_ENDPOINT_URL`: optional, default to nothing
35
+ - `RENDERING_HF_INFERENCE_API_MODEL`: optional, defaults to "stabilityai/stable-diffusion-xl-base-1.0"
36
+ - `RENDERING_REPLICATE_API_MODEL`: optional, defaults to "stabilityai/sdxl"
37
+ - `RENDERING_REPLICATE_API_MODEL_VERSION`: optional, in case you want to change the version
38
+
39
+ Language model config:
40
+ - `LLM_HF_INFERENCE_ENDPOINT_URL`: "https://llama-v2-70b-chat.ngrok.io"
41
+ - `LLM_HF_INFERENCE_API_MODEL`: "codellama/CodeLlama-7b-hf"
42
 
43
  In addition, there are some community sharing variables that you can just ignore.
44
  Those variables are not required to run the AI Comic Factory on your own website or computer
src/app/engine/caption.ts CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import { ImageAnalysisRequest, ImageAnalysisResponse } from "@/types"
4
 
5
- const apiUrl = `${process.env.VIDEOCHAIN_API_URL || ""}`
6
 
7
  export async function see({
8
  prompt,
 
2
 
3
  import { ImageAnalysisRequest, ImageAnalysisResponse } from "@/types"
4
 
5
+ const apiUrl = `${process.env.RENDERING_VIDEOCHAIN_API_URL || ""}`
6
 
7
  export async function see({
8
  prompt,
src/app/engine/render.ts CHANGED
@@ -1,6 +1,7 @@
1
  "use server"
2
 
3
- import Replicate, { Prediction } from "replicate"
 
4
 
5
  import { RenderRequest, RenderedScene, RenderingEngine } from "@/types"
6
  import { generateSeed } from "@/lib/generateSeed"
@@ -8,13 +9,15 @@ import { sleep } from "@/lib/sleep"
8
 
9
  const renderingEngine = `${process.env.RENDERING_ENGINE || ""}` as RenderingEngine
10
 
11
- const replicateToken = `${process.env.REPLICATE_API_TOKEN || ""}`
12
- const replicateModel = `${process.env.REPLICATE_API_MODEL || ""}`
13
- const replicateModelVersion = `${process.env.REPLICATE_API_MODEL_VERSION || ""}`
14
 
15
- // note: there is no / at the end in the variable
16
- // so we have to add it ourselves if needed
17
- const apiUrl = process.env.VIDEOCHAIN_API_URL
 
 
 
 
18
 
19
  export async function newRender({
20
  prompt,
@@ -79,9 +82,74 @@ export async function newRender({
79
  maskUrl: "",
80
  segments: []
81
  } as RenderedScene
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  } else {
83
  // console.log(`calling POST ${apiUrl}/render with prompt: ${prompt}`)
84
- const res = await fetch(`${apiUrl}/render`, {
85
  method: "POST",
86
  headers: {
87
  Accept: "application/json",
@@ -202,7 +270,7 @@ export async function getRender(renderId: string) {
202
  } as RenderedScene
203
  } else {
204
  // console.log(`calling GET ${apiUrl}/render with renderId: ${renderId}`)
205
- const res = await fetch(`${apiUrl}/render/${renderId}`, {
206
  method: "GET",
207
  headers: {
208
  Accept: "application/json",
@@ -255,7 +323,7 @@ export async function upscaleImage(image: string): Promise<{
255
 
256
  try {
257
  // console.log(`calling GET ${apiUrl}/render with renderId: ${renderId}`)
258
- const res = await fetch(`${apiUrl}/upscale`, {
259
  method: "POST",
260
  headers: {
261
  Accept: "application/json",
 
1
  "use server"
2
 
3
+ import { v4 as uuidv4 } from "uuid"
4
+ import Replicate from "replicate"
5
 
6
  import { RenderRequest, RenderedScene, RenderingEngine } from "@/types"
7
  import { generateSeed } from "@/lib/generateSeed"
 
9
 
10
  const renderingEngine = `${process.env.RENDERING_ENGINE || ""}` as RenderingEngine
11
 
12
+ const videochainApiUrl = `${process.env.RENDERING_VIDEOCHAIN_API_URL || "" }`
 
 
13
 
14
+ const huggingFaceToken = `${process.env.HF_API_TOKEN || ""}`
15
+ const huggingFaceInferenceEndpointUrl = `${process.env.RENDERING_HF_INFERENCE_ENDPOINT_URL || ""}`
16
+ const huggingFaceInferenceApiModel = `${process.env.RENDERING_HF_INFERENCE_API_MODEL || ""}`
17
+
18
+ const replicateToken = `${process.env.RENDERING_REPLICATE_API_TOKEN || ""}`
19
+ const replicateModel = `${process.env.RENDERING_REPLICATE_API_MODEL || ""}`
20
+ const replicateModelVersion = `${process.env.RENDERING_REPLICATE_API_MODEL_VERSION || ""}`
21
 
22
  export async function newRender({
23
  prompt,
 
82
  maskUrl: "",
83
  segments: []
84
  } as RenderedScene
85
+ } if (renderingEngine === "INFERENCE_ENDPOINT" || renderingEngine === "INFERENCE_API") {
86
+ if (!huggingFaceToken) {
87
+ throw new Error(`you need to configure your HF_API_TOKEN in order to use the ${renderingEngine} rendering engine`)
88
+ }
89
+ if (renderingEngine === "INFERENCE_ENDPOINT" && !huggingFaceInferenceEndpointUrl) {
90
+ throw new Error(`you need to configure your RENDERING_HF_INFERENCE_ENDPOINT_URL in order to use the INFERENCE_ENDPOINT rendering engine`)
91
+ }
92
+ if (renderingEngine === "INFERENCE_API" && !huggingFaceInferenceApiModel) {
93
+ throw new Error(`you need to configure your RENDERING_HF_INFERENCE_API_MODEL in order to use the INFERENCE_API rendering engine`)
94
+ }
95
+
96
+ const seed = generateSeed()
97
+
98
+ const url = renderingEngine === "INFERENCE_ENDPOINT"
99
+ ? huggingFaceInferenceEndpointUrl
100
+ : `https://api-inference.huggingface.co/models/${huggingFaceInferenceApiModel}`
101
+
102
+ const res = await fetch(url, {
103
+ method: "POST",
104
+ headers: {
105
+ // Accept: "application/json",
106
+ "Content-Type": "application/json",
107
+ Authorization: `Bearer ${huggingFaceToken}`,
108
+ },
109
+ body: JSON.stringify({
110
+ inputs: [
111
+ "beautiful",
112
+ "intricate details",
113
+ prompt,
114
+ "award winning",
115
+ "high resolution"
116
+ ].join(", "),
117
+ parameters: {
118
+ num_inference_steps: 25,
119
+ guidance_scale: 8,
120
+ width,
121
+ height,
122
+
123
+ }
124
+ }),
125
+ cache: "no-store",
126
+ // we can also use this (see https://vercel.com/blog/vercel-cache-api-nextjs-cache)
127
+ // next: { revalidate: 1 }
128
+ })
129
+
130
+
131
+ // Recommendation: handle errors
132
+ if (res.status !== 200) {
133
+ // This will activate the closest `error.js` Error Boundary
134
+ throw new Error('Failed to fetch data')
135
+ }
136
+
137
+ // the result is a JSON-encoded string
138
+ const response = await res.json() as string
139
+ const assetUrl = `data:image/png;base64,${response}`
140
+
141
+ return {
142
+ renderId: uuidv4(),
143
+ status: "completed",
144
+ assetUrl,
145
+ alt: prompt,
146
+ error: "",
147
+ maskUrl: "",
148
+ segments: []
149
+ } as RenderedScene
150
  } else {
151
  // console.log(`calling POST ${apiUrl}/render with prompt: ${prompt}`)
152
+ const res = await fetch(`${videochainApiUrl}/render`, {
153
  method: "POST",
154
  headers: {
155
  Accept: "application/json",
 
270
  } as RenderedScene
271
  } else {
272
  // console.log(`calling GET ${apiUrl}/render with renderId: ${renderId}`)
273
+ const res = await fetch(`${videochainApiUrl}/render/${renderId}`, {
274
  method: "GET",
275
  headers: {
276
  Accept: "application/json",
 
323
 
324
  try {
325
  // console.log(`calling GET ${apiUrl}/render with renderId: ${renderId}`)
326
+ const res = await fetch(`${videochainApiUrl}/upscale`, {
327
  method: "POST",
328
  headers: {
329
  Accept: "application/json",
src/app/queries/predict.ts CHANGED
@@ -8,8 +8,8 @@ const hf = new HfInference(process.env.HF_API_TOKEN)
8
 
9
  // note: we always try "inference endpoint" first
10
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
11
- const inferenceEndpoint = `${process.env.HF_INFERENCE_ENDPOINT_URL || ""}`
12
- const inferenceModel = `${process.env.HF_INFERENCE_API_MODEL || ""}`
13
 
14
  let hfie: HfInferenceEndpoint
15
 
 
8
 
9
  // note: we always try "inference endpoint" first
10
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
11
+ const inferenceEndpoint = `${process.env.LLM_HF_INFERENCE_ENDPOINT_URL || ""}`
12
+ const inferenceModel = `${process.env.LLM_HF_INFERENCE_API_MODEL || ""}`
13
 
14
  let hfie: HfInferenceEndpoint
15
 
src/types.ts CHANGED
@@ -91,6 +91,8 @@ export type RenderingEngine =
91
  | "VIDEOCHAIN"
92
  | "OPENAI"
93
  | "REPLICATE"
 
 
94
 
95
  export type PostVisibility =
96
  | "featured" // featured by admins
 
91
  | "VIDEOCHAIN"
92
  | "OPENAI"
93
  | "REPLICATE"
94
+ | "INFERENCE_API"
95
+ | "INFERENCE_ENDPOINT"
96
 
97
  export type PostVisibility =
98
  | "featured" // featured by admins