wenhao-gao commited on
Commit
81224d9
·
1 Parent(s): 3beee34
Files changed (2) hide show
  1. app.py +105 -125
  2. requirements.txt +3 -5
app.py CHANGED
@@ -1,142 +1,122 @@
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
- #import spaces #[uncomment to use ZeroGPU]
5
- from diffusers import DiffusionPipeline
6
- import torch
7
 
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
 
 
 
 
 
10
 
11
- if torch.cuda.is_available():
12
- torch_dtype = torch.float16
13
- else:
14
- torch_dtype = torch.float32
15
 
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
 
 
18
 
19
- MAX_SEED = np.iinfo(np.int32).max
20
- MAX_IMAGE_SIZE = 1024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- #@spaces.GPU #[uncomment to use ZeroGPU]
23
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
24
 
25
- if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
-
28
- generator = torch.Generator().manual_seed(seed)
29
-
30
- image = pipe(
31
- prompt = prompt,
32
- negative_prompt = negative_prompt,
33
- guidance_scale = guidance_scale,
34
- num_inference_steps = num_inference_steps,
35
- width = width,
36
- height = height,
37
- generator = generator
38
- ).images[0]
39
-
40
- return image, seed
41
 
42
  examples = [
43
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
44
- "An astronaut riding a green horse",
45
- "A delicious ceviche cheesecake slice",
 
 
46
  ]
47
 
48
- css="""
49
- #col-container {
50
- margin: 0 auto;
51
- max-width: 640px;
52
- }
53
- """
54
 
55
- with gr.Blocks(css=css) as demo:
56
-
57
- with gr.Column(elem_id="col-container"):
58
- gr.Markdown(f"""
59
- # Text-to-Image Gradio Template
 
 
 
 
 
 
60
  """)
61
-
62
- with gr.Row():
63
-
64
- prompt = gr.Text(
65
- label="Prompt",
66
- show_label=False,
67
- max_lines=1,
68
- placeholder="Enter your prompt",
69
- container=False,
70
- )
71
-
72
- run_button = gr.Button("Run", scale=0)
73
-
74
- result = gr.Image(label="Result", show_label=False)
75
 
76
- with gr.Accordion("Advanced Settings", open=False):
77
-
78
- negative_prompt = gr.Text(
79
- label="Negative prompt",
80
- max_lines=1,
81
- placeholder="Enter a negative prompt",
82
- visible=False,
83
- )
84
-
85
- seed = gr.Slider(
86
- label="Seed",
87
- minimum=0,
88
- maximum=MAX_SEED,
89
- step=1,
90
- value=0,
91
- )
92
-
93
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
-
95
- with gr.Row():
96
-
97
- width = gr.Slider(
98
- label="Width",
99
- minimum=256,
100
- maximum=MAX_IMAGE_SIZE,
101
- step=32,
102
- value=1024, #Replace with defaults that work for your model
103
- )
104
-
105
- height = gr.Slider(
106
- label="Height",
107
- minimum=256,
108
- maximum=MAX_IMAGE_SIZE,
109
- step=32,
110
- value=1024, #Replace with defaults that work for your model
111
- )
112
-
113
  with gr.Row():
114
-
115
- guidance_scale = gr.Slider(
116
- label="Guidance scale",
117
- minimum=0.0,
118
- maximum=10.0,
119
- step=0.1,
120
- value=0.0, #Replace with defaults that work for your model
121
- )
122
-
123
- num_inference_steps = gr.Slider(
124
- label="Number of inference steps",
125
- minimum=1,
126
- maximum=50,
127
- step=1,
128
- value=2, #Replace with defaults that work for your model
129
- )
130
-
131
- gr.Examples(
132
- examples = examples,
133
- inputs = [prompt]
134
- )
135
- gr.on(
136
- triggers=[run_button.click, prompt.submit],
137
- fn = infer,
138
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
139
- outputs = [result, seed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  )
141
 
142
- demo.queue().launch()
 
1
  import gradio as gr
2
+ from gradio_molecule2d import molecule2d
3
+ from synformer.chem.mol import Molecule
4
+ from synformer.sampler.analog.parallel import run_sampling_one_cpu
5
+ from huggingface_hub import hf_hub_download
 
6
 
7
+ REPO_ID = "whgao/synformer"
8
+ CKPT_FILENAME = "sf_ed_default.ckpt"
9
+ MAT_FILENAME = "matrix.pkl"
10
+ FPI_FILENAME = "fpindex.pkl"
11
+ ckpt_path = hf_hub_download(REPO_ID, CKPT_FILENAME)
12
+ mat_path = hf_hub_download(REPO_ID, MAT_FILENAME)
13
+ fpi_path = hf_hub_download(REPO_ID, FPI_FILENAME)
14
 
15
+ last_result = {}
 
 
 
16
 
17
+ # Function to clear all inputs
18
+ def clear_inputs():
19
+ # Return default or empty values to reset each input component
20
+ return None, 24, 64, 0
21
 
22
+ def sample(smi, search_width, exhaustiveness):
23
+ result_df = run_sampling_one_cpu(
24
+ input=Molecule(smi),
25
+ model_path=ckpt_path,
26
+ mat_path=mat_path,
27
+ fpi_path=fpi_path,
28
+ search_width=search_width,
29
+ exhaustiveness=exhaustiveness,
30
+ time_limit=180,
31
+ max_results=100,
32
+ max_evolve_steps=24,
33
+ sort_by_scores=True,
34
+ )
35
+ result_df = result_df[:30]
36
+ last_result["results_df"] = result_df
37
+ smiles = result_df.iloc[0]["smiles"]
38
+ similarity = result_df.iloc[0]["score"]
39
+ synthesis = result_df.iloc[0]["synthesis"]
40
+ return smiles, similarity, synthesis, gr.update(maximum=len(result_df)-1)
41
 
 
 
42
 
43
+ def select_from_output(index):
44
+ df_results = last_result["results_df"]
45
+ return df_results.iloc[index]["smiles"], df_results.iloc[index]["score"], df_results.iloc[index]["synthesis"]
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  examples = [
48
+ "Nc1cccc(S(=O)(=O)N2CCCN(S(=O)(=O)c3ccc4c(c3)OCCO4)CC2)c1",
49
+ "CN1C[C@H](Nc2cnn(C)c(=O)c2)C[C@H](c2ccccc2)C1",
50
+ "COc1ccc(-c2ccnc(Nc3ccccc3)n2)cc1",
51
+ "CC[C@@H]1OC[C@@]23Cc4cc(F)c(N)cc4-c4ccc5c(c42)C(=CC(F)(F)O5)[C@@H]1C3=O",
52
+ "O=C(OCC(=O)N1[C@H](C(=O)O)C[C@@H]2CCCC[C@@H]21)[C@H](Cc1cbccc1)NC(I)c1bcccc1",
53
  ]
54
 
 
 
 
 
 
 
55
 
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown(f"""
58
+ # Demo of [SynFormer](https://github.com/wenhao-gao/synformer/tree/main)
59
+ This page demonstrates the SynFormer-ED model, which takes a molecule as input—regardless of its synthetic accessibility—and outputs
60
+ identical or approximate molecules along with their associated synthetic paths. The demo runs on CPUs and typically takes about
61
+ one minute per run but can be accelerated by reducing the search width and exhaustiveness. The model may take longer if the server
62
+ is busy. Since the sampling is stochastic, you may run the demo multiple times to explore different results, with a maximum of
63
+ 30 molecules displayed at once.
64
+ To learn more about SynFormer’s architecture and applications, check out [our paper](https://github.com/wenhao-gao/synformer/tree/main).
65
+
66
+ Authors: [Wenhao Gao](mailto:[email protected]), Shitong Luo, Connor W. Coley
67
  """)
68
+ with gr.Row():
69
+ with gr.Column(scale=0.5):
70
+ input_molecule = molecule2d(label="SMILES Input")
71
+ slider_1 = gr.Slider(minimum=1, maximum=100, step=1, label="Search Width", value=24)
72
+ slider_2 = gr.Slider(minimum=1, maximum=100, step=1, label="Exhaustiveness", value=64)
 
 
 
 
 
 
 
 
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  with gr.Row():
75
+ with gr.Column(scale=0.5):
76
+ run_btn = gr.Button("Run on sample")
77
+ with gr.Column(scale=0.5):
78
+ clear_btn = gr.Button("Clear")
79
+
80
+ with gr.Column(scale=0.5):
81
+ index_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Select Output Index", value=0, interactive=True)
82
+ output_similarity = gr.Text(label="Tanimoto Similarity")
83
+ output_molecule = molecule2d(label="Output")
84
+ output_synpath = gr.Textbox(label="Synthetic Path")
85
+
86
+ with gr.Row():
87
+ with gr.Column(scale=1):
88
+ gr.Markdown("### Examples")
89
+ gr.Examples(
90
+ examples = examples,
91
+ inputs = [input_molecule]
92
+ )
93
+
94
+ run_btn.click(
95
+ fn=sample,
96
+ inputs=[
97
+ input_molecule,
98
+ slider_1,
99
+ slider_2
100
+ ],
101
+ outputs=[
102
+ output_molecule,
103
+ output_similarity,
104
+ output_synpath,
105
+ index_slider
106
+ ],
107
+ api_name="Run"
108
+ )
109
+
110
+ index_slider.change(
111
+ fn=select_from_output,
112
+ inputs=[index_slider],
113
+ outputs=[output_molecule, output_similarity, output_synpath],
114
+ )
115
+
116
+ clear_btn.click(
117
+ fn=clear_inputs,
118
+ inputs=[],
119
+ outputs=[input_molecule, slider_1, slider_2, index_slider]
120
  )
121
 
122
+ demo.launch()
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
  torch
5
- transformers
6
- xformers
 
1
+ gradio_molecule2d
 
 
2
  torch
3
+ synformer
4
+ huggingface-hub