Update app.py
Browse files
app.py
CHANGED
@@ -17,6 +17,8 @@ def load_model(config_path, ckpt_path):
|
|
17 |
with open(ckpt_path, "rb") as f:
|
18 |
leaves = pickle.load(f)
|
19 |
|
|
|
|
|
20 |
from model import DiT, DiTConfig
|
21 |
|
22 |
dit_config = DiTConfig(**config["model"])
|
@@ -32,13 +34,16 @@ def sample_images(graphdef, state, x0, t):
|
|
32 |
flow = nnx.merge(graphdef, state)
|
33 |
|
34 |
def flow_fn(y, t):
|
|
|
|
|
35 |
o = flow(y, t[None])
|
36 |
-
return o
|
37 |
|
38 |
-
o = ode.odeint(flow_fn, x0, t, rtol=1e-
|
39 |
o = jnp.clip(o[-1], 0, 1)
|
40 |
return o
|
41 |
|
|
|
42 |
@spaces.GPU
|
43 |
def generate_grid(seed, noise_level):
|
44 |
# Load model (doing this inside function to avoid global variables)
|
@@ -66,8 +71,6 @@ def generate_grid(seed, noise_level):
|
|
66 |
return jax.device_get(grid)
|
67 |
|
68 |
|
69 |
-
generate_grid(0, 3)
|
70 |
-
|
71 |
# Create Gradio interface
|
72 |
demo = gr.Interface(
|
73 |
fn=generate_grid,
|
|
|
17 |
with open(ckpt_path, "rb") as f:
|
18 |
leaves = pickle.load(f)
|
19 |
|
20 |
+
leaves = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), leaves)
|
21 |
+
|
22 |
from model import DiT, DiTConfig
|
23 |
|
24 |
dit_config = DiTConfig(**config["model"])
|
|
|
34 |
flow = nnx.merge(graphdef, state)
|
35 |
|
36 |
def flow_fn(y, t):
|
37 |
+
y = y.astype(jnp.bfloat16)
|
38 |
+
t = t.astype(jnp.bfloat16)
|
39 |
o = flow(y, t[None])
|
40 |
+
return o.astype(jnp.float32)
|
41 |
|
42 |
+
o = ode.odeint(flow_fn, x0, t, rtol=1e-4)
|
43 |
o = jnp.clip(o[-1], 0, 1)
|
44 |
return o
|
45 |
|
46 |
+
|
47 |
@spaces.GPU
|
48 |
def generate_grid(seed, noise_level):
|
49 |
# Load model (doing this inside function to avoid global variables)
|
|
|
71 |
return jax.device_get(grid)
|
72 |
|
73 |
|
|
|
|
|
74 |
# Create Gradio interface
|
75 |
demo = gr.Interface(
|
76 |
fn=generate_grid,
|