ntt123 commited on
Commit
e8774dc
·
verified ·
1 Parent(s): 6960251

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -17,6 +17,8 @@ def load_model(config_path, ckpt_path):
17
  with open(ckpt_path, "rb") as f:
18
  leaves = pickle.load(f)
19
 
 
 
20
  from model import DiT, DiTConfig
21
 
22
  dit_config = DiTConfig(**config["model"])
@@ -32,13 +34,16 @@ def sample_images(graphdef, state, x0, t):
32
  flow = nnx.merge(graphdef, state)
33
 
34
  def flow_fn(y, t):
 
 
35
  o = flow(y, t[None])
36
- return o
37
 
38
- o = ode.odeint(flow_fn, x0, t, rtol=1e-2)
39
  o = jnp.clip(o[-1], 0, 1)
40
  return o
41
 
 
42
  @spaces.GPU
43
  def generate_grid(seed, noise_level):
44
  # Load model (doing this inside function to avoid global variables)
@@ -66,8 +71,6 @@ def generate_grid(seed, noise_level):
66
  return jax.device_get(grid)
67
 
68
 
69
- generate_grid(0, 3)
70
-
71
  # Create Gradio interface
72
  demo = gr.Interface(
73
  fn=generate_grid,
 
17
  with open(ckpt_path, "rb") as f:
18
  leaves = pickle.load(f)
19
 
20
+ leaves = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), leaves)
21
+
22
  from model import DiT, DiTConfig
23
 
24
  dit_config = DiTConfig(**config["model"])
 
34
  flow = nnx.merge(graphdef, state)
35
 
36
  def flow_fn(y, t):
37
+ y = y.astype(jnp.bfloat16)
38
+ t = t.astype(jnp.bfloat16)
39
  o = flow(y, t[None])
40
+ return o.astype(jnp.float32)
41
 
42
+ o = ode.odeint(flow_fn, x0, t, rtol=1e-4)
43
  o = jnp.clip(o[-1], 0, 1)
44
  return o
45
 
46
+
47
  @spaces.GPU
48
  def generate_grid(seed, noise_level):
49
  # Load model (doing this inside function to avoid global variables)
 
71
  return jax.device_get(grid)
72
 
73
 
 
 
74
  # Create Gradio interface
75
  demo = gr.Interface(
76
  fn=generate_grid,