Bils committed · verified
Commit f6fe860 · 1 Parent(s): 641b21d

Create app.py

Files changed (1):
  app.py  +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoConfig, AutoModelForCausalLM
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
+ from PIL import Image
+ import numpy as np
+ import spaces
+
+ # Load the model and processor
+ model_path = "deepseek-ai/Janus-Pro-7B"
+ config = AutoConfig.from_pretrained(model_path)
+ language_config = config.language_config
+ language_config._attn_implementation = 'eager'
+
+ vl_gpt = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     language_config=language_config,
+     trust_remote_code=True
+ )
+ if torch.cuda.is_available():
+     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+ else:
+     vl_gpt = vl_gpt.to(torch.float16)
+
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Helper functions
+ def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, patch_size=16):
+     torch.cuda.empty_cache()
+
+     # Duplicate the prompt per sample: even rows keep the full prompt
+     # (conditional branch), odd rows are padded out (unconditional branch)
+     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+     for i in range(parallel_size * 2):
+         tokens[i, :] = input_ids
+         if i % 2 != 0:
+             tokens[i, 1:-1] = vl_chat_processor.pad_id
+
+     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+     generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int).to(cuda_device)
+
+     pkv = None
+     for i in range(576):  # 576 image tokens = a 24x24 grid at patch_size 16
+         with torch.no_grad():
+             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
+         pkv = outputs.past_key_values
+         hidden_states = outputs.last_hidden_state
+         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+
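+         # Classifier-free guidance: mix the conditional (even-row) and
+         # unconditional (odd-row) logits; a larger cfg_weight pushes the
+         # samples closer to the prompt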
+         logit_cond = logits[0::2, :]
+         logit_uncond = logits[1::2, :]
+         logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+
+         probs = torch.softmax(logits / temperature, dim=-1)
+         next_token = torch.multinomial(probs, num_samples=1)
+         generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+         # Feed the sampled token back in, duplicated for both CFG branches
+         next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
+         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+         inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+     # Decode the generated image tokens back to pixel patches
+     patches = vl_gpt.gen_vision_model.decode_code(
+         generated_tokens.to(dtype=torch.int),
+         shape=[parallel_size, 8, width // patch_size, height // patch_size]
+     )
+     return patches
+
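+ # The decoder output is in [-1, 1]; unpack rescales it to [0, 255] and
+ # converts each sample to a PIL image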
+ def unpack(patches, width, height, parallel_size=5):
+     patches = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+     patches = np.clip((patches + 1) / 2 * 255, 0, 255)
+
+     images = [Image.fromarray(patches[i].astype(np.uint8)) for i in range(parallel_size)]
+     return images
+
+ @torch.inference_mode()
+ @spaces.GPU(duration=120)
+ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
+     torch.cuda.empty_cache()
+
+     if seed is not None:
+         # Seed all RNGs so the same inputs reproduce the same images
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         np.random.seed(seed)
+
+     width, height, parallel_size = 384, 384, 5
+
+     messages = [
+         {'role': '<|User|>', 'content': prompt},
+         {'role': '<|Assistant|>', 'content': ''}
+     ]
+
+     # Render the chat template, then append the image-start tag so the model
+     # continues with image tokens rather than text
+     text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+         conversations=messages, sft_format=vl_chat_processor.sft_format, system_prompt=''
+     )
+     text += vl_chat_processor.image_start_tag
+
+     input_ids = torch.LongTensor(tokenizer.encode(text))
+     patches = generate(input_ids, width, height, cfg_weight=guidance, temperature=t2i_temperature, parallel_size=parallel_size)
+
+     return unpack(patches, width, height, parallel_size)
+
+ # Gradio interface
+ def create_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Text-to-Image Generation")
+
+         prompt_input = gr.Textbox(label="Prompt (describe the image)")
+         seed_input = gr.Number(label="Seed (Optional)", value=12345, precision=0)
+         guidance_slider = gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5)
+         temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, value=1.0, step=0.05)
+
+         generate_button = gr.Button("Generate Images")
+         output_gallery = gr.Gallery(label="Generated Images", columns=2, height=300)
+
+         generate_button.click(
+             generate_image,
+             inputs=[prompt_input, seed_input, guidance_slider, temperature_slider],
+             outputs=output_gallery
+         )
+
+     return demo
+
+ demo = create_interface()
+ demo.launch(share=True)
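
The pipeline can also be exercised without the Gradio UI. The sketch below is a hypothetical usage example, not part of this commit: it assumes the janus and spaces packages are installed, a GPU with enough memory for the 7B checkpoint, and that demo.launch(share=True) above is guarded with if __name__ == "__main__" (or commented out) so the server does not block:

    # Hypothetical smoke test (not part of the commit): sample five images
    # for one prompt and write them to disk.
    images = generate_image(
        "a red fox in fresh snow, golden hour",  # example prompt
        seed=12345,
        guidance=5,
        t2i_temperature=1.0,
    )
    for idx, img in enumerate(images):  # parallel_size=5 PIL images, 384x384
        img.save(f"janus_sample_{idx}.png")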