hr16 commited on
Commit
0d8c2d3
·
1 Parent(s): 36bec38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -9,6 +9,7 @@ import stylegan2
9
  from stylegan2 import utils
10
  from huggingface_hub import hf_hub_download
11
  import gradio as gr
 
12
 
13
  from types import SimpleNamespace
14
 
@@ -78,9 +79,17 @@ def generate_images(G, args):
78
 
79
 
80
  #----------------------------------------------------------------------------
 
 
 
 
 
81
 
82
- def inference(seed, truncation_psi, model_version):
83
- G = stylegan2.models.load(hf_hub_download("hr16/Gwern-TWDNEv3-pytorch_ckpt", f"{model_version}/Gs.pth", use_auth_token=os.environ['MODEL_READING_TOKEN']))
 
 
 
84
  G.eval()
85
  return generate_images(
86
  G,
@@ -103,14 +112,10 @@ gr.Interface(
103
  gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may different from the original site)"),
104
  gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
105
  gr.Radio(
106
- ["iteration-24664", "iteration-18528", "iteration-17325"],
107
- value="iteration-24664",
108
  type="value",
109
- label=[
110
- "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
111
- "TWDNEv3 iteration 18528 (the most used version on Internet)",
112
- "TWDNEv3 iteration 17325"
113
- ]
114
  )
115
  ],
116
  gr.outputs.Image(type="pil"),
 
9
  from stylegan2 import utils
10
  from huggingface_hub import hf_hub_download
11
  import gradio as gr
12
+ import re
13
 
14
  from types import SimpleNamespace
15
 
 
79
 
80
 
81
  #----------------------------------------------------------------------------
82
+ interface_modelversion_labels = [
83
+ "TWDNEv3 iteration 24664 (best and current version on TWDNE)",
84
+ "TWDNEv3 iteration 18528 (the most used version on the Internet)",
85
+ "TWDNEv3 iteration 17325"
86
+ ]
87
 
88
+ def inference(seed, truncation_psi, modelversion_label):
89
+ model_iteration = re.search("TWDNEv3 iteration (\d{5})", modelversion_label).group(1)
90
+ G = stylegan2.models.load(
91
+ hf_hub_download("hr16/Gwern-TWDNEv3-pytorch_ckpt", f"iteration-{model_iteration}/Gs.pth", use_auth_token=os.environ['MODEL_READING_TOKEN'])
92
+ )
93
  G.eval()
94
  return generate_images(
95
  G,
 
112
  gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may different from the original site)"),
113
  gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
114
  gr.Radio(
115
+ interface_modelversion_labels,
116
+ value="TWDNEv3 iteration 24664 (best and current version on TWDNE)",
117
  type="value",
118
+ label="Model versions"
 
 
 
 
119
  )
120
  ],
121
  gr.outputs.Image(type="pil"),