lnyan committed on
Commit
8ad6ef6
·
1 Parent(s): 61b0e28
Files changed (1) hide show
  1. app.py +25 -2
app.py CHANGED
@@ -92,7 +92,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
92
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
93
  return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16)
94
 
95
- @spaces.GPU(duration=30)
96
  def load_encoders():
97
  is_schnell = True
98
  t5 = load_t5("cuda", max_length=256 if is_schnell else 512)
@@ -109,7 +109,7 @@ def b64(txt,vec):
109
 
110
  t5,clip=load_encoders()
111
 
112
- @spaces.GPU(duration=10)
113
  def convert(prompt):
114
  if isinstance(prompt, str):
115
  prompt = [prompt]
@@ -118,6 +118,25 @@ def convert(prompt):
118
  vec = clip.tokenize(prompt)
119
  vec = clip(vec)
120
  return b64(txt,vec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  with gr.Blocks() as demo:
123
  gr.Markdown("""A workaround for flux-flax to fit into 40G VRAM""")
@@ -125,10 +144,14 @@ with gr.Blocks() as demo:
125
  with gr.Column():
126
  prompt = gr.Textbox(label="prompt")
127
  convert_btn = gr.Button(value="Convert")
 
128
  with gr.Column():
129
  output = gr.Textbox(label="output")
130
 
131
  convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")
 
 
 
132
 
133
 
134
  demo.launch()
 
92
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
93
  return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16)
94
 
95
+ @spaces.GPU(duration=60)
96
  def load_encoders():
97
  is_schnell = True
98
  t5 = load_t5("cuda", max_length=256 if is_schnell else 512)
 
109
 
110
  t5,clip=load_encoders()
111
 
112
+ @spaces.GPU(duration=20)
113
  def convert(prompt):
114
  if isinstance(prompt, str):
115
  prompt = [prompt]
 
118
  vec = clip.tokenize(prompt)
119
  vec = clip(vec)
120
  return b64(txt,vec)
121
+ import jax
122
def _to_embed(t5, clip, txt, vec):
    """Rebuild both encoders from their split form and run them on the tokens.

    Args:
        t5: Sequence unpacked into ``nnx.merge`` to recover the T5 encoder
            (the module-level caller passes the result of ``nnx.split``).
        clip: Sequence unpacked into ``nnx.merge`` to recover the CLIP encoder.
        txt: Input passed to the rebuilt T5 encoder.
        vec: Input passed to the rebuilt CLIP encoder.

    Returns:
        A ``(t5_output, clip_output)`` tuple.
    """
    # Merge in the same order as the original: T5 first, then CLIP.
    t5_module = nnx.merge(*t5)
    clip_module = nnx.merge(*clip)

    t5_output = t5_module(txt)
    clip_output = clip_module(vec)
    return t5_output, clip_output
126
+
127
# JIT-compiled entry point that embeds with both encoders in a single call.
to_embed = jax.jit(_to_embed)

# Split each encoder once at import time; `_to_embed` re-assembles them with
# `nnx.merge(*...)`, so the split results are what gets passed through the
# jitted function instead of the module objects themselves.
t5_tuple = nnx.split(t5)
clip_tuple = nnx.split(clip)
131
+
132
@spaces.GPU(duration=120)
def compile(prompt):
    """Embed ``prompt`` through the jitted encoder path and encode the result.

    NOTE: the name shadows the ``compile`` builtin; it is kept because the
    Gradio wiring (``compile_btn.click(compile, ...)``) refers to it.

    Args:
        prompt: A single prompt string or a list of prompt strings. A lone
            string is wrapped in a one-element list before tokenizing.

    Returns:
        Whatever ``b64`` produces for the ``(t5_output, clip_output)`` pair
        (presumably a base64 text payload for the output textbox -- confirm
        against ``b64``).
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    txt_tokens = t5.tokenize(prompt)
    vec_tokens = clip.tokenize(prompt)

    # Fix: the original bound the T5 output to an unused name (`text`) and
    # then handed the raw token ids (`txt`) to `b64`. Pass the encoder
    # outputs instead, matching what `convert` sends to `b64`.
    txt, vec = to_embed(t5_tuple, clip_tuple, txt_tokens, vec_tokens)
    return b64(txt, vec)
140
 
141
  with gr.Blocks() as demo:
142
  gr.Markdown("""A workaround for flux-flax to fit into 40G VRAM""")
 
144
  with gr.Column():
145
  prompt = gr.Textbox(label="prompt")
146
  convert_btn = gr.Button(value="Convert")
147
+ compile_btn = gr.Button(value="Compile")
148
  with gr.Column():
149
  output = gr.Textbox(label="output")
150
 
151
  convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")
152
+ compile_btn.click(compile, inputs=prompt, outputs=output, api_name="compile")
153
+
154
+
155
 
156
 
157
  demo.launch()