WideMax commited on
Commit
80e7e8e
·
1 Parent(s): f1a1b02

Initial commit

Browse files
Files changed (8) hide show
  1. app.py +96 -0
  2. inference_pb2.py +30 -0
  3. inference_pb2.pyi +25 -0
  4. inference_pb2_grpc.py +102 -0
  5. input/1.png +0 -0
  6. input/13.png +0 -0
  7. input/zebra.jpeg +0 -0
  8. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+
4
+ import gradio as gr
5
+ import grpc
6
+ from PIL import Image
7
+ from cachetools import LRUCache
8
+ import hashlib
9
+
10
+ from protos.inference_pb2 import GuideAndRescaleRequest, GuideAndRescaleResponse
11
+ from protos.inference_pb2_grpc import GuideAndRescaleServiceStub
12
+
13
+
14
+ def get_bytes(img):
15
+ if img is None:
16
+ return img
17
+
18
+ buffered = BytesIO()
19
+ img.save(buffered, format="JPEG")
20
+ return buffered.getvalue()
21
+
22
+
23
+ def bytes_to_image(image: bytes) -> Image.Image:
24
+ image = Image.open(BytesIO(image))
25
+ return image
26
+
27
+
28
+ def resize(img):
29
+ if img.size != (512, 512):
30
+ img = img.resize((512, 512), Image.Resampling.LANCZOS)
31
+
32
+ return img
33
+
34
+
35
+ def edit(image, source_prompt, target_prompt, config, progress=gr.Progress(track_tqdm=True)):
36
+ if not image or not source_prompt or not target_prompt:
37
+ raise ValueError("Need to upload an image and enter init and edit prompts")
38
+
39
+ image_bytes = get_bytes(image)
40
+ os.environ['SERVER'] = "0.0.0.0:50052"
41
+ with grpc.insecure_channel(os.environ['SERVER']) as channel:
42
+ stub = GuideAndRescaleServiceStub(channel)
43
+
44
+ output: GuideAndRescaleResponse = stub.swap(
45
+ GuideAndRescaleRequest(image=image_bytes, source_prompt=source_prompt, target_prompt=target_prompt,
46
+ config=config, use_cache=True)
47
+ )
48
+
49
+ output = bytes_to_image(output.image)
50
+ return output
51
+
52
+
53
+ def get_demo():
54
+ with gr.Blocks() as demo:
55
+ gr.Markdown("## Guide-and-Rescale")
56
+ gr.Markdown(
57
+ '<div style="display: flex; align-items: center; gap: 10px;">'
58
+ '<span>Official Guide-and-Rescale Gradio demo:</span>'
59
+ '<a href="https://github.com/AIRI-Institute/Guide-and-Rescale"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
60
+ '<a href="https://colab.research.google.com/drive/1noKOOcDBBL_m5_UqU15jBBqiM8piLZ1O?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
61
+ '</div>'
62
+ )
63
+ with gr.Row():
64
+ with gr.Column():
65
+ with gr.Row():
66
+ image = gr.Image(label="Image that you want to edit", type="pil")
67
+ with gr.Row():
68
+ source_prompt = gr.Textbox(label="Init Prompt", info="Describs the content on the original image.")
69
+ target_prompt = gr.Textbox(label="Edit Prompt", info="Describs what is expected in the output image.")
70
+ config = gr.Radio(["non-stylisation", "stylisation"], value='non-stylisation',
71
+ label="Type of Editing", info="Selects a config for editing.")
72
+ with gr.Row():
73
+ btn = gr.Button("Edit image")
74
+ with gr.Column():
75
+ with gr.Row():
76
+ output = gr.Image(label="Result: edited image")
77
+
78
+ gr.Examples(examples=[["input/1.png", 'A photo of a tiger', 'A photo of a lion', 'non-stylisation'], ["input/zebra.jpeg", 'A photo of a zebra', 'A photo of a white horse', 'non-stylisation'], ["input/13.png", 'A photo', 'Anime style face', 'stylisation']], inputs=[image, source_prompt, target_prompt, config],
79
+ outputs=output)
80
+
81
+ image.upload(fn=resize, inputs=[image], outputs=image)
82
+
83
+ btn.click(fn=edit, inputs=[image, source_prompt, target_prompt, config], outputs=output)
84
+
85
+ gr.Markdown('''To cite the paper by the authors
86
+ ```
87
+ TODO: add cite
88
+ ```
89
+ ''')
90
+ return demo
91
+
92
+
93
+ if __name__ == '__main__':
94
+ align_cache = LRUCache(maxsize=10)
95
+ demo = get_demo()
96
+ demo.launch(server_name="0.0.0.0", server_port=7860)
inference_pb2.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: inference.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\"x\n\x16GuideAndRescaleRequest\x12\r\n\x05image\x18\x01 \x01(\x0c\x12\x15\n\rsource_prompt\x18\x02 \x01(\t\x12\x15\n\rtarget_prompt\x18\x03 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x04 \x01(\t\x12\x11\n\tuse_cache\x18\x05 \x01(\x08\"(\n\x17GuideAndRescaleResponse\x12\r\n\x05image\x18\x01 \x01(\x0c\x32g\n\x16GuideAndRescaleService\x12M\n\x04swap\x12!.inference.GuideAndRescaleRequest\x1a\".inference.GuideAndRescaleResponseb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_GUIDEANDRESCALEREQUEST']._serialized_start=30
25
+ _globals['_GUIDEANDRESCALEREQUEST']._serialized_end=150
26
+ _globals['_GUIDEANDRESCALERESPONSE']._serialized_start=152
27
+ _globals['_GUIDEANDRESCALERESPONSE']._serialized_end=192
28
+ _globals['_GUIDEANDRESCALESERVICE']._serialized_start=194
29
+ _globals['_GUIDEANDRESCALESERVICE']._serialized_end=297
30
+ # @@protoc_insertion_point(module_scope)
inference_pb2.pyi ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google.protobuf import descriptor as _descriptor
2
+ from google.protobuf import message as _message
3
+ from typing import ClassVar as _ClassVar, Optional as _Optional
4
+
5
+ DESCRIPTOR: _descriptor.FileDescriptor
6
+
7
+ class GuideAndRescaleRequest(_message.Message):
8
+ __slots__ = ("image", "source_prompt", "target_prompt", "config", "use_cache")
9
+ IMAGE_FIELD_NUMBER: _ClassVar[int]
10
+ SOURCE_PROMPT_FIELD_NUMBER: _ClassVar[int]
11
+ TARGET_PROMPT_FIELD_NUMBER: _ClassVar[int]
12
+ CONFIG_FIELD_NUMBER: _ClassVar[int]
13
+ USE_CACHE_FIELD_NUMBER: _ClassVar[int]
14
+ image: bytes
15
+ source_prompt: str
16
+ target_prompt: str
17
+ config: str
18
+ use_cache: bool
19
+ def __init__(self, image: _Optional[bytes] = ..., source_prompt: _Optional[str] = ..., target_prompt: _Optional[str] = ..., config: _Optional[str] = ..., use_cache: bool = ...) -> None: ...
20
+
21
+ class GuideAndRescaleResponse(_message.Message):
22
+ __slots__ = ("image",)
23
+ IMAGE_FIELD_NUMBER: _ClassVar[int]
24
+ image: bytes
25
+ def __init__(self, image: _Optional[bytes] = ...) -> None: ...
inference_pb2_grpc.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+ import warnings
5
+
6
+ import inference_pb2 as inference__pb2
7
+
8
+ GRPC_GENERATED_VERSION = '1.65.1'
9
+ GRPC_VERSION = grpc.__version__
10
+ EXPECTED_ERROR_RELEASE = '1.66.0'
11
+ SCHEDULED_RELEASE_DATE = 'August 6, 2024'
12
+ _version_not_supported = False
13
+
14
+ try:
15
+ from grpc._utilities import first_version_is_lower
16
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
17
+ except ImportError:
18
+ _version_not_supported = True
19
+
20
+ if _version_not_supported:
21
+ warnings.warn(
22
+ f'The grpc package installed is at version {GRPC_VERSION},'
23
+ + f' but the generated code in inference_pb2_grpc.py depends on'
24
+ + f' grpcio>={GRPC_GENERATED_VERSION}.'
25
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
26
+ + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
27
+ + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
28
+ + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
29
+ RuntimeWarning
30
+ )
31
+
32
+
33
+ class GuideAndRescaleServiceStub(object):
34
+ """Missing associated documentation comment in .proto file."""
35
+
36
+ def __init__(self, channel):
37
+ """Constructor.
38
+
39
+ Args:
40
+ channel: A grpc.Channel.
41
+ """
42
+ self.swap = channel.unary_unary(
43
+ '/inference.GuideAndRescaleService/swap',
44
+ request_serializer=inference__pb2.GuideAndRescaleRequest.SerializeToString,
45
+ response_deserializer=inference__pb2.GuideAndRescaleResponse.FromString,
46
+ _registered_method=True)
47
+
48
+
49
+ class GuideAndRescaleServiceServicer(object):
50
+ """Missing associated documentation comment in .proto file."""
51
+
52
+ def swap(self, request, context):
53
+ """Missing associated documentation comment in .proto file."""
54
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
55
+ context.set_details('Method not implemented!')
56
+ raise NotImplementedError('Method not implemented!')
57
+
58
+
59
+ def add_GuideAndRescaleServiceServicer_to_server(servicer, server):
60
+ rpc_method_handlers = {
61
+ 'swap': grpc.unary_unary_rpc_method_handler(
62
+ servicer.swap,
63
+ request_deserializer=inference__pb2.GuideAndRescaleRequest.FromString,
64
+ response_serializer=inference__pb2.GuideAndRescaleResponse.SerializeToString,
65
+ ),
66
+ }
67
+ generic_handler = grpc.method_handlers_generic_handler(
68
+ 'inference.GuideAndRescaleService', rpc_method_handlers)
69
+ server.add_generic_rpc_handlers((generic_handler,))
70
+ server.add_registered_method_handlers('inference.GuideAndRescaleService', rpc_method_handlers)
71
+
72
+
73
+ # This class is part of an EXPERIMENTAL API.
74
+ class GuideAndRescaleService(object):
75
+ """Missing associated documentation comment in .proto file."""
76
+
77
+ @staticmethod
78
+ def swap(request,
79
+ target,
80
+ options=(),
81
+ channel_credentials=None,
82
+ call_credentials=None,
83
+ insecure=False,
84
+ compression=None,
85
+ wait_for_ready=None,
86
+ timeout=None,
87
+ metadata=None):
88
+ return grpc.experimental.unary_unary(
89
+ request,
90
+ target,
91
+ '/inference.GuideAndRescaleService/swap',
92
+ inference__pb2.GuideAndRescaleRequest.SerializeToString,
93
+ inference__pb2.GuideAndRescaleResponse.FromString,
94
+ options,
95
+ channel_credentials,
96
+ insecure,
97
+ call_credentials,
98
+ compression,
99
+ wait_for_ready,
100
+ timeout,
101
+ metadata,
102
+ _registered_method=True)
input/1.png ADDED
input/13.png ADDED
input/zebra.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pillow==10.0.0
2
+ grpcio==1.63.0
3
+ grpcio_tools==1.63.0
4
+ gradio==4.31.5
5
+ cachetools==5.3.3