Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
import os | |
import random | |
import uuid | |
import gradio as gr | |
import numpy as np | |
#import spaces | |
import torch | |
from PIL import Image | |
from evosdxl_jp_v1 import load_evosdxl_jp | |
import devicetorch | |
DESCRIPTION = """# 🐟 EvoSDXL-JP | |
🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs) | |
[EvoSDXL-JP](https://huggingface.co/SakanaAI/EvoSDXL-JP-v1)は[Sakana AI](https://sakana.ai/)が教育目的で開発した日本特化の高速な画像生成モデルです。 | |
入力した日本語プロンプトに沿った画像を生成することができます。より詳しくは、上記のブログをご参照ください。 | |
""" | |
if not torch.cuda.is_available(): | |
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>" | |
MAX_SEED = np.iinfo(np.int32).max | |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1" | |
#device = "cuda" if torch.cuda.is_available() else "cpu" | |
device = devicetorch.get(torch) | |
NUM_IMAGES_PER_PROMPT = 1 | |
ENABLE_CPU_OFFLOAD = False | |
USE_TORCH_COMPILE = False | |
#SAFETY_CHECKER = True | |
SAFETY_CHECKER = False | |
DEVELOP_MODE = True | |
if SAFETY_CHECKER: | |
from safety_checker import StableDiffusionSafetyChecker | |
from transformers import CLIPFeatureExtractor | |
safety_checker = StableDiffusionSafetyChecker.from_pretrained( | |
"CompVis/stable-diffusion-safety-checker" | |
).to(device) | |
feature_extractor = CLIPFeatureExtractor.from_pretrained( | |
"openai/clip-vit-base-patch32" | |
) | |
def check_nsfw_images( | |
images: list[Image.Image], | |
) -> tuple[list[Image.Image], list[bool]]: | |
safety_checker_input = feature_extractor(images, return_tensors="pt").to(device) | |
has_nsfw_concepts = safety_checker( | |
images=[images], | |
clip_input=safety_checker_input.pixel_values.to(device) | |
) | |
return images, has_nsfw_concepts | |
#pipe = load_evosdxl_jp("cpu").to("cuda") | |
pipe = load_evosdxl_jp("cpu").to(device) | |
def show_warning(warning_text: str) -> gr.Blocks: | |
with gr.Blocks() as demo: | |
gr.Markdown(warning_text) | |
return demo | |
def save_image(img): | |
unique_name = str(uuid.uuid4()) + ".png" | |
img.save(unique_name) | |
return unique_name | |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: | |
if randomize_seed: | |
seed = random.randint(0, MAX_SEED) | |
return seed | |
def generate( | |
prompt: str, | |
seed: int = 0, | |
randomize_seed: bool = False, | |
progress=gr.Progress(track_tqdm=True), | |
): | |
pipe.to(device) | |
seed = int(randomize_seed_fn(seed, randomize_seed)) | |
generator = torch.Generator().manual_seed(seed) | |
images = pipe( | |
prompt=prompt, | |
width=1024, | |
height=1024, | |
guidance_scale=0, | |
num_inference_steps=4, | |
generator=generator, | |
num_images_per_prompt=NUM_IMAGES_PER_PROMPT, | |
output_type="pil", | |
).images | |
if SAFETY_CHECKER: | |
images, has_nsfw_concepts = check_nsfw_images(images) | |
if any(has_nsfw_concepts): | |
gr.Warning("NSFW content detected.") | |
return Image.new("RGB", (512, 512), "WHITE"), seed | |
return images[0], seed | |
examples = [ | |
"柴犬が草原に立つ、幻想的な空、アート、最高品質の写真、ピントが当たってる" | |
] | |
css = ''' | |
.gradio-container{max-width: 690px !important} | |
h1{text-align:center} | |
''' | |
with gr.Blocks(css=css) as demo: | |
gr.Markdown(DESCRIPTION) | |
with gr.Group(): | |
with gr.Row(): | |
prompt = gr.Textbox(placeholder="日本語でプロンプトを入力してください。", show_label=False, scale=8) | |
submit = gr.Button(scale=0) | |
result = gr.Image(label="EvoSDXL-JPからの生成結果", show_label=False) | |
with gr.Accordion("詳細設定", open=False): | |
seed = gr.Slider( | |
label="シード値", | |
minimum=0, | |
maximum=MAX_SEED, | |
step=1, | |
value=0, | |
) | |
randomize_seed = gr.Checkbox(label="ランダムにシード値を決定", value=True) | |
# gr.Examples( | |
# examples=examples, | |
# inputs=prompt, | |
# outputs=[result, seed], | |
# fn=generate, | |
# # cache_examples=CACHE_EXAMPLES, | |
# ) | |
gr.on( | |
triggers=[ | |
prompt.submit, | |
submit.click, | |
], | |
fn=generate, | |
inputs=[ | |
prompt, | |
seed, | |
randomize_seed, | |
], | |
outputs=[result, seed], | |
api_name="run", | |
) | |
gr.Markdown("""⚠️ 本モデルは実験段階のプロトタイプであり、教育および研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。 | |
本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。 | |
Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。 | |
利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""") | |
demo.queue().launch() | |