EvoSDXL-JP / app.py
cocktailpeanut's picture
update
d82dd30
#!/usr/bin/env python
import os
import random
import uuid
import gradio as gr
import numpy as np
#import spaces
import torch
from PIL import Image
from evosdxl_jp_v1 import load_evosdxl_jp
import devicetorch
DESCRIPTION = """# 🐟 EvoSDXL-JP
🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
[EvoSDXL-JP](https://huggingface.co/SakanaAI/EvoSDXL-JP-v1)は[Sakana AI](https://sakana.ai/)が教育目的で開発した日本特化の高速な画像生成モデルです。
入力した日本語プロンプトに沿った画像を生成することができます。より詳しくは、上記のブログをご参照ください。
"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = devicetorch.get(torch)
NUM_IMAGES_PER_PROMPT = 1
ENABLE_CPU_OFFLOAD = False
USE_TORCH_COMPILE = False
#SAFETY_CHECKER = True
SAFETY_CHECKER = False
DEVELOP_MODE = True
if SAFETY_CHECKER:
from safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
).to(device)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
"openai/clip-vit-base-patch32"
)
def check_nsfw_images(
images: list[Image.Image],
) -> tuple[list[Image.Image], list[bool]]:
safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
has_nsfw_concepts = safety_checker(
images=[images],
clip_input=safety_checker_input.pixel_values.to(device)
)
return images, has_nsfw_concepts
#pipe = load_evosdxl_jp("cpu").to("cuda")
pipe = load_evosdxl_jp("cpu").to(device)
def show_warning(warning_text: str) -> gr.Blocks:
with gr.Blocks() as demo:
gr.Markdown(warning_text)
return demo
def save_image(img):
unique_name = str(uuid.uuid4()) + ".png"
img.save(unique_name)
return unique_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
@spaces.GPU
def generate(
prompt: str,
seed: int = 0,
randomize_seed: bool = False,
progress=gr.Progress(track_tqdm=True),
):
pipe.to(device)
seed = int(randomize_seed_fn(seed, randomize_seed))
generator = torch.Generator().manual_seed(seed)
images = pipe(
prompt=prompt,
width=1024,
height=1024,
guidance_scale=0,
num_inference_steps=4,
generator=generator,
num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
output_type="pil",
).images
if SAFETY_CHECKER:
images, has_nsfw_concepts = check_nsfw_images(images)
if any(has_nsfw_concepts):
gr.Warning("NSFW content detected.")
return Image.new("RGB", (512, 512), "WHITE"), seed
return images[0], seed
examples = [
"柴犬が草原に立つ、幻想的な空、アート、最高品質の写真、ピントが当たってる"
]
css = '''
.gradio-container{max-width: 690px !important}
h1{text-align:center}
'''
with gr.Blocks(css=css) as demo:
gr.Markdown(DESCRIPTION)
with gr.Group():
with gr.Row():
prompt = gr.Textbox(placeholder="日本語でプロンプトを入力してください。", show_label=False, scale=8)
submit = gr.Button(scale=0)
result = gr.Image(label="EvoSDXL-JPからの生成結果", show_label=False)
with gr.Accordion("詳細設定", open=False):
seed = gr.Slider(
label="シード値",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="ランダムにシード値を決定", value=True)
# gr.Examples(
# examples=examples,
# inputs=prompt,
# outputs=[result, seed],
# fn=generate,
# # cache_examples=CACHE_EXAMPLES,
# )
gr.on(
triggers=[
prompt.submit,
submit.click,
],
fn=generate,
inputs=[
prompt,
seed,
randomize_seed,
],
outputs=[result, seed],
api_name="run",
)
gr.Markdown("""⚠️ 本モデルは実験段階のプロトタイプであり、教育および研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。
本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。
Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
demo.queue().launch()