DongfuJiang commited on
Commit
350d553
·
1 Parent(s): 625938c
model/model_manager.py CHANGED
@@ -3,7 +3,6 @@ import random
3
  import gradio as gr
4
  import requests
5
  import io, base64, json
6
- import spaces
7
  from PIL import Image
8
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, load_pipeline
9
 
@@ -21,7 +20,6 @@ class ModelManager:
21
  pipe = self.loaded_models[model_name]
22
  return pipe
23
 
24
- @spaces.GPU(duration=60)
25
  def generate_image_ig(self, prompt, model_name):
26
  pipe = self.load_model_pipe(model_name)
27
  result = pipe(prompt=prompt)
@@ -51,15 +49,9 @@ class ModelManager:
51
  results.append(result)
52
  return results[0], results[1]
53
 
54
- @spaces.GPU(duration=150)
55
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
56
  pipe = self.load_model_pipe(model_name)
57
- if 'PNP' in model_name:
58
- result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct, num_inversion_steps=100)
59
- elif 'Prompt2prompt' in model_name:
60
- result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct, num_inner_steps=5)
61
- else:
62
- result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
63
  return result
64
 
65
  def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
 
3
  import gradio as gr
4
  import requests
5
  import io, base64, json
 
6
  from PIL import Image
7
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, load_pipeline
8
 
 
20
  pipe = self.loaded_models[model_name]
21
  return pipe
22
 
 
23
  def generate_image_ig(self, prompt, model_name):
24
  pipe = self.load_model_pipe(model_name)
25
  result = pipe(prompt=prompt)
 
49
  results.append(result)
50
  return results[0], results[1]
51
 
 
52
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
53
  pipe = self.load_model_pipe(model_name)
54
+ result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
 
 
 
 
 
55
  return result
56
 
57
  def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
model/models/imagenhub_models.py CHANGED
@@ -7,5 +7,27 @@ class ImagenHubModel():
7
  def __call__(self, *args, **kwargs):
8
  return self.model.infer_one_image(*args, **kwargs)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def load_imagenhub_model(model_name, model_type=None):
 
 
 
 
11
  return ImagenHubModel(model_name)
 
7
  def __call__(self, *args, **kwargs):
8
  return self.model.infer_one_image(*args, **kwargs)
9
 
10
+ class PNP(ImagenHubModel):
11
+ def __init__(self):
12
+ super().__init__('PNP')
13
+
14
+ def __call__(self, *args, **kwargs):
15
+ if "num_inversion_steps" not in kwargs:
16
+ kwargs["num_inversion_steps"] = 100
17
+ return super().__call__(*args, **kwargs)
18
+
19
+ class Prompt2prompt(ImagenHubModel):
20
+ def __init__(self):
21
+ super().__init__('Prompt2prompt')
22
+
23
+ def __call__(self, *args, **kwargs):
24
+ if "num_inner_steps" not in kwargs:
25
+ kwargs["num_inner_steps"] = 5
26
+ return super().__call__(*args, **kwargs)
27
+
28
  def load_imagenhub_model(model_name, model_type=None):
29
+ if model_name == 'PNP':
30
+ return PNP()
31
+ if model_name == 'Prompt2prompt':
32
+ return Prompt2prompt()
33
  return ImagenHubModel(model_name)
serve/constants.py CHANGED
@@ -13,4 +13,5 @@ LOG_SERVER_ADDR = os.getenv("LOG_SERVER_ADDR", f"{LOG_SERVER}/{LOG_SERVER_SUBDOA
13
  # LOG SERVER API ENDPOINTS
14
  APPEND_JSON = "append_json"
15
  SAVE_IMAGE = "save_image"
 
16
 
 
13
  # LOG SERVER API ENDPOINTS
14
  APPEND_JSON = "append_json"
15
  SAVE_IMAGE = "save_image"
16
+ SAVE_LOG = "save_log"
17
 
serve/log_server.py CHANGED
@@ -4,9 +4,9 @@ import json
4
  import os
5
  import aiofiles
6
  from .log_utils import build_logger
7
- from .constants import LOG_SERVER_SUBDOAMIN, APPEND_JSON, SAVE_IMAGE
8
 
9
- logger = build_logger("log_server", "log_server.log")
10
 
11
  app = APIRouter(prefix=f"/{LOG_SERVER_SUBDOAMIN}")
12
 
@@ -37,3 +37,20 @@ async def save_image(image: UploadFile = File(...), image_path: str = Form(...))
37
  await f.write(content) # Write the image content to a file
38
  logger.info(f"Image saved successfully at {image_path}")
39
  return {"message": f"Image saved successfully at {image_path}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import os
5
  import aiofiles
6
  from .log_utils import build_logger
7
+ from .constants import LOG_SERVER_SUBDOAMIN, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
8
 
9
+ logger = build_logger("log_server", "log_server.log", add_remote_handler=False)
10
 
11
  app = APIRouter(prefix=f"/{LOG_SERVER_SUBDOAMIN}")
12
 
 
37
  await f.write(content) # Write the image content to a file
38
  logger.info(f"Image saved successfully at {image_path}")
39
  return {"message": f"Image saved successfully at {image_path}"}
40
+
41
+ @app.post(f"/{SAVE_LOG}")
42
+ async def save_log(message: str = Form(...), log_path: str = Form(...)):
43
+ """
44
+ Save a log message to a specified log file on the server.
45
+ """
46
+ print(f"Received log message: {message} to be saved at: {log_path}")
47
+ # Ensure the directory for the log file exists
48
+ if os.path.dirname(log_path):
49
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
50
+
51
+ # Append the log message to the specified log file
52
+ async with aiofiles.open(log_path, mode='a') as f:
53
+ await f.write(f"{message}\n")
54
+
55
+ logger.info(f"Romote log message saved to {log_path}")
56
+ return {"message": f"Log message saved successfully to {log_path}"}
serve/log_utils.py CHANGED
@@ -10,17 +10,36 @@ import platform
10
  import sys
11
  from typing import AsyncGenerator, Generator
12
  import warnings
 
13
 
14
  import requests
15
 
16
- from .constants import LOGDIR
 
17
 
18
 
19
  handler = None
20
  visited_loggers = set()
21
 
22
 
23
- def build_logger(logger_name, logger_filename):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  global handler
25
 
26
  formatter = logging.Formatter(
@@ -56,6 +75,15 @@ def build_logger(logger_name, logger_filename):
56
  # Get logger
57
  logger = logging.getLogger(logger_name)
58
  logger.setLevel(logging.INFO)
 
 
 
 
 
 
 
 
 
59
 
60
  # if LOGDIR is empty, then don't try output log to local file
61
  if LOGDIR != "":
 
10
  import sys
11
  from typing import AsyncGenerator, Generator
12
  import warnings
13
+ from pathlib import Path
14
 
15
  import requests
16
 
17
+ from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
18
+ from .utils import save_log_str_on_log_server
19
 
20
 
21
  handler = None
22
  visited_loggers = set()
23
 
24
 
25
+ # Assuming LOGDIR and other necessary imports and global variables are defined
26
+
27
+ class APIHandler(logging.Handler):
28
+ """Custom logging handler that sends logs to an API."""
29
+
30
+ def __init__(self, apiUrl, log_path, *args, **kwargs):
31
+ super(APIHandler, self).__init__(*args, **kwargs)
32
+ self.apiUrl = apiUrl
33
+ self.log_path = log_path
34
+
35
+ def emit(self, record):
36
+ log_entry = self.format(record)
37
+ try:
38
+ save_log_str_on_log_server(log_entry, self.log_path)
39
+ except requests.RequestException as e:
40
+ print(f"Error sending log to API: {e}", file=sys.stderr)
41
+
42
+ def build_logger(logger_name, logger_filename, add_remote_handler=True):
43
  global handler
44
 
45
  formatter = logging.Formatter(
 
75
  # Get logger
76
  logger = logging.getLogger(logger_name)
77
  logger.setLevel(logging.INFO)
78
+
79
+ if add_remote_handler:
80
+ # Add APIHandler to send logs to your API
81
+ api_url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
82
+
83
+ remote_logger_filename = str(Path(logger_filename).stem + "_remote.log")
84
+ api_handler = APIHandler(apiUrl=api_url, log_path=f"{LOGDIR}/{remote_logger_filename}")
85
+ api_handler.setFormatter(formatter)
86
+ logger.addHandler(api_handler)
87
 
88
  # if LOGDIR is empty, then don't try output log to local file
89
  if LOGDIR != "":
serve/utils.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  import gradio as gr
7
  from pathlib import Path
8
  from model.model_registry import *
9
- from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE
10
  from typing import Union
11
 
12
 
@@ -159,4 +159,12 @@ def append_json_item_on_log_server(json_item: Union[dict, str], log_file: str):
159
  url = f"{LOG_SERVER_ADDR}/{APPEND_JSON}"
160
  # Make the POST request, sending the JSON string and the log file name
161
  response = requests.post(url, data={'json_str': json_item, 'file_name': log_file})
 
 
 
 
 
 
 
 
162
  return response
 
6
  import gradio as gr
7
  from pathlib import Path
8
  from model.model_registry import *
9
+ from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
10
  from typing import Union
11
 
12
 
 
159
  url = f"{LOG_SERVER_ADDR}/{APPEND_JSON}"
160
  # Make the POST request, sending the JSON string and the log file name
161
  response = requests.post(url, data={'json_str': json_item, 'file_name': log_file})
162
+ return response
163
+
164
+ def save_log_str_on_log_server(log_str: str, log_file: str):
165
+ log_file = Path(log_file).absolute().relative_to(os.getcwd())
166
+ log_file = str(log_file)
167
+ url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
168
+ # Make the POST request, sending the log message and the log file name
169
+ response = requests.post(url, data={'message': log_str, 'log_path': log_file})
170
  return response