Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ import stylegan2
|
|
9 |
from stylegan2 import utils
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import gradio as gr
|
|
|
12 |
|
13 |
from types import SimpleNamespace
|
14 |
|
@@ -78,9 +79,17 @@ def generate_images(G, args):
|
|
78 |
|
79 |
|
80 |
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
def inference(seed, truncation_psi,
|
83 |
-
|
|
|
|
|
|
|
84 |
G.eval()
|
85 |
return generate_images(
|
86 |
G,
|
@@ -103,14 +112,10 @@ gr.Interface(
|
|
103 |
gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may different from the original site)"),
|
104 |
gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
|
105 |
gr.Radio(
|
106 |
-
|
107 |
-
value="iteration
|
108 |
type="value",
|
109 |
-
label=
|
110 |
-
"TWDNEv3 iteration 24664 (best and current version on TWDNE)",
|
111 |
-
"TWDNEv3 iteration 18528 (the most used version on Internet)",
|
112 |
-
"TWDNEv3 iteration 17325"
|
113 |
-
]
|
114 |
)
|
115 |
],
|
116 |
gr.outputs.Image(type="pil"),
|
|
|
9 |
from stylegan2 import utils
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
import gradio as gr
|
12 |
+
import re
|
13 |
|
14 |
from types import SimpleNamespace
|
15 |
|
|
|
79 |
|
80 |
|
81 |
#----------------------------------------------------------------------------
|
82 |
+
interface_modelversion_labels = [
|
83 |
+
"TWDNEv3 iteration 24664 (best and current version on TWDNE)",
|
84 |
+
"TWDNEv3 iteration 18528 (the most used version on the Internet)",
|
85 |
+
"TWDNEv3 iteration 17325"
|
86 |
+
]
|
87 |
|
88 |
+
def inference(seed, truncation_psi, modelversion_label):
|
89 |
+
model_iteration = re.search("TWDNEv3 iteration (\d{5})", modelversion_label).group(1)
|
90 |
+
G = stylegan2.models.load(
|
91 |
+
hf_hub_download("hr16/Gwern-TWDNEv3-pytorch_ckpt", f"iteration-{model_iteration}/Gs.pth", use_auth_token=os.environ['MODEL_READING_TOKEN'])
|
92 |
+
)
|
93 |
G.eval()
|
94 |
return generate_images(
|
95 |
G,
|
|
|
112 |
gr.Number(precision=0, label="PCG64 PRNG Seed (any-bit-size unsigned int, note that it may different from the original site)"),
|
113 |
gr.Slider(0, 2, step=0.1, value=0.7, label='Truncation psi (aka creative level, between 0 and 2)'),
|
114 |
gr.Radio(
|
115 |
+
interface_modelversion_labels,
|
116 |
+
value="TWDNEv3 iteration 24664 (best and current version on TWDNE)",
|
117 |
type="value",
|
118 |
+
label="Model versions"
|
|
|
|
|
|
|
|
|
119 |
)
|
120 |
],
|
121 |
gr.outputs.Image(type="pil"),
|