ksort committed on
Commit
afdb110
·
1 Parent(s): 0e503f3

Add new model

Browse files
model/model_manager.py CHANGED
@@ -5,6 +5,7 @@ import requests
5
  import io, base64, json
6
  import spaces
7
  from PIL import Image
 
8
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, load_pipeline
9
  from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum
10
  from serve.upload import get_random_mscoco_prompt
@@ -26,7 +27,7 @@ class ModelManager:
26
  @spaces.GPU(duration=120)
27
  def generate_image_ig(self, prompt, model_name):
28
  pipe = self.load_model_pipe(model_name)
29
- if 'cascade' not in name:
30
  result = pipe(prompt=prompt).images[0]
31
  else:
32
  prior, decoder = pipe
@@ -40,7 +41,6 @@ class ModelManager:
40
  num_images_per_prompt=1,
41
  num_inference_steps=20
42
  )
43
-
44
  decoder.enable_model_cpu_offload()
45
  result = decoder(
46
  image_embeddings=prior_output.image_embeddings.to(torch.float16),
@@ -55,6 +55,7 @@ class ModelManager:
55
  def generate_image_ig_api(self, prompt, model_name):
56
  pipe = self.load_model_pipe(model_name)
57
  result = pipe(prompt=prompt)
 
58
  return result
59
 
60
  def generate_image_ig_museum(self, model_name):
 
5
  import io, base64, json
6
  import spaces
7
  from PIL import Image
8
+ from openai import OpenAI
9
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, load_pipeline
10
  from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum
11
  from serve.upload import get_random_mscoco_prompt
 
27
  @spaces.GPU(duration=120)
28
  def generate_image_ig(self, prompt, model_name):
29
  pipe = self.load_model_pipe(model_name)
30
+ if 'Stable-cascade' not in name:
31
  result = pipe(prompt=prompt).images[0]
32
  else:
33
  prior, decoder = pipe
 
41
  num_images_per_prompt=1,
42
  num_inference_steps=20
43
  )
 
44
  decoder.enable_model_cpu_offload()
45
  result = decoder(
46
  image_embeddings=prior_output.image_embeddings.to(torch.float16),
 
55
  def generate_image_ig_api(self, prompt, model_name):
56
  pipe = self.load_model_pipe(model_name)
57
  result = pipe(prompt=prompt)
58
+
59
  return result
60
 
61
  def generate_image_ig_museum(self, model_name):
model/models/__init__.py CHANGED
@@ -5,21 +5,8 @@ from .fal_api_models import load_fal_model
5
  from .huggingface_models import load_huggingface_model
6
  from .replicate_api_models import load_replicate_model
7
  from .openai_api_models import load_openai_model
 
8
 
9
- # IMAGE_GENERATION_MODELS = ['huggingface_SD-v1.5_text2image',
10
- # 'huggingface_SD-v2.1_text2image',
11
- # 'huggingface_SD-XL-v1.0_text2image',
12
- # 'huggingface_IF-I-XL-v1.0_text2image',
13
- # ]
14
-
15
- # IMAGE_GENERATION_MODELS = [ 'imagenhub_SD_generation',
16
- # 'imagenhub_SDXL_generation',
17
- # 'imagenhub_OpenJourney_generation',
18
- # 'imagenhub_LCM_generation',
19
- # 'imagenhub_DeepFloydIF_generation',
20
- # 'imagenhub_PixArtAlpha_generation',
21
- # 'imagenhub_Kandinsky_generation',
22
- # ]
23
 
24
  IMAGE_GENERATION_MODELS = [
25
  'replicate_SDXL_text2image',
@@ -44,6 +31,11 @@ IMAGE_GENERATION_MODELS = [
44
  'replicate_Deepfloyd-IF_text2image',
45
  'huggingface_SD-turbo_text2image',
46
  'huggingface_SDXL-turbo_text2image',
 
 
 
 
 
47
  ]
48
 
49
 
@@ -78,7 +70,9 @@ def load_pipeline(model_name):
78
  elif model_source == "huggingface":
79
  pipe = load_huggingface_model(model_name, model_type)
80
  elif model_source == "openai":
81
- pipe = load_openai_model(model_name)
 
 
82
  else:
83
  raise ValueError(f"Model source {model_source} not supported")
84
  return pipe
 
5
  from .huggingface_models import load_huggingface_model
6
  from .replicate_api_models import load_replicate_model
7
  from .openai_api_models import load_openai_model
8
+ from .other_api_models import load_other_model
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  IMAGE_GENERATION_MODELS = [
12
  'replicate_SDXL_text2image',
 
31
  'replicate_Deepfloyd-IF_text2image',
32
  'huggingface_SD-turbo_text2image',
33
  'huggingface_SDXL-turbo_text2image',
34
+ 'huggingface_Stable-cascade_text2image',
35
+ 'openai_Dalle-2_text2image',
36
+ 'openai_Dalle-3_text2image',
37
+ 'other_Midjourney-v6.0_text2image',
38
+ 'other_Midjourney-v5.0_text2image',
39
  ]
40
 
41
 
 
70
  elif model_source == "huggingface":
71
  pipe = load_huggingface_model(model_name, model_type)
72
  elif model_source == "openai":
73
+ pipe = load_openai_model(model_name, model_type)
74
+ elif model_source == "other":
75
+ pipe = load_other_model(model_name, model_type)
76
  else:
77
  raise ValueError(f"Model source {model_source} not supported")
78
  return pipe
model/models/huggingface_models.py CHANGED
@@ -4,8 +4,6 @@ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
4
  import torch
5
 
6
 
7
-
8
-
9
  def load_huggingface_model(model_name, model_type):
10
  if model_name == "SD-turbo":
11
  pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
@@ -30,10 +28,10 @@ def load_huggingface_model(model_name, model_type):
30
 
31
 
32
  if __name__ == "__main__":
33
- for name in ["SD-turbo", "SDXL-turbo"]: #"SD-turbo", "SDXL-turbo"
34
- pipe = load_huggingface_model(name, "text2image")
35
-
36
 
37
  # for name in ["IF-I-XL-v1.0"]:
38
  # pipe = load_huggingface_model(name, 'text2image')
39
  # pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
 
 
4
  import torch
5
 
6
 
 
 
7
  def load_huggingface_model(model_name, model_type):
8
  if model_name == "SD-turbo":
9
  pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
 
28
 
29
 
30
  if __name__ == "__main__":
31
+ # for name in ["SD-turbo", "SDXL-turbo"]: #"SD-turbo", "SDXL-turbo"
32
+ # pipe = load_huggingface_model(name, "text2image")
 
33
 
34
  # for name in ["IF-I-XL-v1.0"]:
35
  # pipe = load_huggingface_model(name, 'text2image')
36
  # pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
37
+
model/models/openai_api_models.py CHANGED
@@ -1,33 +1,59 @@
1
  from openai import OpenAI
 
 
 
 
 
2
 
3
 
4
- def load_openai_model(model_name):
5
- client = OpenAI()
 
 
 
 
6
 
7
- if model_name == "Dalle-3":
8
- response = client.images.generate(
9
- model="dall-e-3",
10
- prompt="a white siamese cat",
11
- size="1024x1024",
12
- quality="standard",
13
- n=1,
14
- )
15
- elif model_name == "Dalle-2":
16
- response = client.images.generate(
17
- model="dall-e-2",
18
- prompt="a white siamese cat",
19
- size="512x512",
20
- quality="standard",
21
- n=1,
22
- )
23
- else:
24
- raise NotImplementedError
25
 
26
- image_url = response.data[0].url
27
 
28
- return image_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  if __name__ == "__main__":
32
- image_url = load_openai_model('Dalle-3')
33
- print(image_url)
 
 
 
1
  from openai import OpenAI
2
+ from PIL import Image
3
+ import requests
4
+ import io
5
+ import os
6
+ import base64
7
 
8
 
9
class OpenaiModel():
    """Callable wrapper that generates one image through the OpenAI Images API.

    An instance is configured with a model name (containing ``Dalle-2`` or
    ``Dalle-3``) and a model type; calling it with ``prompt=...`` requests a
    single image, downloads it from the URL in the API response and returns
    it as a ``PIL.Image.Image``.
    """

    # (substring looked for in model_name, API model id, requested size).
    # Checked in order, so a name containing "Dalle-3" wins over "Dalle-2".
    _MODEL_SPECS = (
        ("Dalle-3", "dall-e-3", "1024x1024"),
        ("Dalle-2", "dall-e-2", "512x512"),
    )

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Generate an image for ``kwargs["prompt"]``.

        Raises:
            ValueError: if ``model_type`` is not ``"text2image"`` or no
                ``prompt`` keyword argument was supplied.
            NotImplementedError: if ``model_name`` matches no known model.
        """
        if self.model_type != "text2image":
            # Only text2image is implemented, so the message says exactly that.
            raise ValueError("model_type must be text2image")
        if "prompt" not in kwargs:
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise ValueError("prompt is required for text2image model")

        # Resolve the model before building a client so an unsupported name
        # fails fast, without needing API credentials.
        for marker, api_model, size in self._MODEL_SPECS:
            if marker in self.model_name:
                break
        else:
            raise NotImplementedError(
                "Unsupported OpenAI model: {}".format(self.model_name)
            )

        client = OpenAI()  # a single client per call is sufficient
        generation = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )

        result_url = generation.data[0].url
        download = requests.get(result_url)
        # Surface HTTP failures here rather than as an opaque PIL decode error.
        download.raise_for_status()
        return Image.open(io.BytesIO(download.content))
48
+
49
+
50
+
51
def load_openai_model(model_name, model_type):
    """Build the callable OpenAI pipeline for the given model name and type."""
    pipeline = OpenaiModel(model_name, model_type)
    return pipeline
53
 
54
 
55
  if __name__ == "__main__":
56
+ pipe = load_openai_model('Dalle-2', 'text2image')
57
+ result = pipe(prompt='draw a tiger')
58
+ print(result)
59
+
model/models/other_api_models.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ from PIL import Image
5
+ import io, time
6
+
7
class OtherModel():
    """Callable client for the Midjourney endpoints hosted at xdai.online.

    Calling an instance with ``prompt=...`` submits an "imagine" job, polls
    the image endpoint once per second until the picture is available, and
    returns the top-left quarter of the downloaded image as a
    ``PIL.Image.Image``.
    """

    # Version flag appended to the prompt for each supported model name.
    _VERSION_FLAGS = {
        "Midjourney-v6.0": "--v 6.0",
        "Midjourney-v5.0": "--v 5.0",
    }
    # Maximum number of one-second polls for the finished image.
    _POLL_ATTEMPTS = 120

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type
        self.url = "https://www.xdai.online/mj/submit/imagine"
        self.key = os.environ.get('MIDJOURNEY_KEY')
        self.get_url = "https://www.xdai.online/mj/image/"
        # Maximum number of submit attempts per call. This is configuration
        # only: it is never mutated while a request is in flight.
        self.repeat_num = 5

    def __call__(self, *args, **kwargs):
        """Generate an image for ``kwargs["prompt"]``.

        Raises:
            ValueError: if ``model_type`` is not ``"text2image"``, no
                ``prompt`` was supplied, the submit request never returned
                HTTP 200, or the image never became available.
            NotImplementedError: if ``model_name`` is not a supported model.
        """
        if self.model_type != "text2image":
            raise ValueError("model_type must be text2image")
        if "prompt" not in kwargs:
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise ValueError("prompt is required for text2image model")

        version_flag = self._VERSION_FLAGS.get(self.model_name)
        if version_flag is None:
            raise NotImplementedError(
                "Unsupported model: {}".format(self.model_name)
            )

        payload = json.dumps({
            "base64Array": [],
            "notifyHook": "",
            "prompt": "{} {}".format(kwargs["prompt"], version_flag),
            "state": "",
            "botType": "MID_JOURNEY",
        })
        headers = {
            "Authorization": "Bearer {}".format(self.key),
            "Content-Type": "application/json"
        }

        # The retry budget is a local loop bound. Previously it was counted
        # down on self.repeat_num, which stayed at 0 after a failure and made
        # every later call on the same instance retry forever.
        for _ in range(self.repeat_num):
            response = requests.post(self.url, data=payload, headers=headers)
            if response.status_code == 200:
                print("Submit success!")
                response_json = json.loads(response.content.decode('utf-8'))
                img_id = response_json["result"]
                return self._fetch_image(self.get_url + img_id)
        raise ValueError("API request failed.")

    def _fetch_image(self, result_url):
        """Poll ``result_url`` until it returns HTTP 200, then crop the image."""
        for _ in range(self._POLL_ATTEMPTS):
            time.sleep(1)
            img_response = requests.get(result_url)
            if img_response.status_code == 200:
                picture = Image.open(io.BytesIO(img_response.content))
                width, height = picture.size
                # Keep only the top-left quarter of the returned picture
                # (presumably one tile of a 2x2 grid; confirm with the API).
                return picture.crop((0, 0, width // 2, height // 2))
        raise ValueError("Image request failed.")
74
def load_other_model(model_name, model_type):
    """Build the callable pipeline for a model served by the "other" source."""
    pipeline = OtherModel(model_name, model_type)
    return pipeline
76
+
77
if __name__ == "__main__":
    # Manual smoke test: generate one Midjourney v5.0 image and print it.
    midjourney = load_other_model("Midjourney-v5.0", "text2image")
    image = midjourney(prompt="a good girl")
    print(image)
82
+
83
+
84
+