IMvision12 commited on
Commit
826c17e
·
1 Parent(s): f4a013c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -5,7 +5,8 @@ import tensorflow as tf
5
  import gradio as gr
6
  import numpy as np
7
 
8
- model = from_pretrained_keras("keras-io/WGAN-GP")
 
9
 
10
  title = "WGAN-GP"
11
  description = "Image Generation Using WGAN"
@@ -17,7 +18,7 @@ article = """
17
  </p>
18
  """
19
 
20
- def Predict(num_images):
21
  random_latent_vectors = tf.random.normal(shape=(int(num_images), 128))
22
  preds = model(random_latent_vectors)
23
  num = ceil(sqrt(num_images))
@@ -32,13 +33,21 @@ def Predict(num_images):
32
  return images
33
 
34
 
 
 
 
 
 
 
 
35
  examples = [[5],[8],[2],[3]]
36
-
 
37
 
38
  interface = gr.Interface(
39
- fn = Predict,
40
- inputs = ["number"],
41
- outputs = ["image"],
42
  examples = examples,
43
  description = description,
44
  title = title,
 
5
  import gradio as gr
6
  import numpy as np
7
 
8
+ model1 = tf.keras.models.load_model("mnist.h5", compile=False)
9
+ model2 = from_pretrained_keras("keras-io/WGAN-GP")
10
 
11
  title = "WGAN-GP"
12
  description = "Image Generation Using WGAN"
 
18
  </p>
19
  """
20
 
21
+ def Predict(model, num_images):
22
  random_latent_vectors = tf.random.normal(shape=(int(num_images), 128))
23
  preds = model(random_latent_vectors)
24
  num = ceil(sqrt(num_images))
 
33
  return images
34
 
35
 
36
+ def inference(num_images, select: str):
37
+ if select == 'fmnist':
38
+ result = create_digit_samples(model2, num_images)
39
+ else:
40
+ result = create_digit_samples(model1, num_images)
41
+ return result
42
+
43
  examples = [[5],[8],[2],[3]]
44
+ inputs = [gr.inputs.Number(label="number of images"), gr.inputs.Radio(['fmnist', 'mnist'])]
45
+ outputs = gr.outputs.Image(label="Output Image")
46
 
47
  interface = gr.Interface(
48
+ fn = inference,
49
+ inputs = inputs,
50
+ outputs = outputs,
51
  examples = examples,
52
  description = description,
53
  title = title,