thuanz123 commited on
Commit
ffa45e8
·
1 Parent(s): 31197f9

Upload 18 files

Browse files
Files changed (18) hide show
  1. .gitattributes +0 -1
  2. .gitignore +165 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. README.md +5 -4
  7. app.py +75 -0
  8. app_inference.py +162 -0
  9. app_training.py +152 -0
  10. app_upload.py +100 -0
  11. constants.py +6 -0
  12. inference.py +83 -0
  13. requirements.txt +9 -0
  14. style.css +3 -0
  15. train_realfill.py +952 -0
  16. trainer.py +168 -0
  17. uploader.py +42 -0
  18. utils.py +48 -0
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ experiments/
3
+ wandb/
4
+
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: train_realfill.py
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
  title: RealFill Training UI
3
- emoji: 📈
4
- colorFrom: purple
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.50.2
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
  title: RealFill Training UI
3
+ emoji:
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.16.2
8
+ python_version: 3.10.9
9
  app_file: app.py
10
  pinned: false
11
  license: mit
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from app_inference import create_inference_demo
11
+ from app_training import create_training_demo
12
+ from app_upload import create_upload_demo
13
+ from inference import InferencePipeline
14
+ from trainer import Trainer
15
+
16
+ TITLE = '# RealFill Training UI'
17
+
18
+ ORIGINAL_SPACE_ID = 'realfill-library/RealFill-Training-UI'
19
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
21
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
22
+ '''
23
+
24
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
25
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
26
+ else:
27
+ SETTINGS = 'Settings'
28
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
29
+ <center>
30
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
31
+ "T4 small" is sufficient to run this demo.
32
+ </center>
33
+ '''
34
+
35
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
36
+ <center>
37
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
38
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
39
+ </center>
40
+ '''
41
+
42
+ HF_TOKEN = os.getenv('HF_TOKEN')
43
+
44
+
45
+ def show_warning(warning_text: str) -> gr.Blocks:
46
+ with gr.Blocks() as demo:
47
+ with gr.Box():
48
+ gr.Markdown(warning_text)
49
+ return demo
50
+
51
+
52
+ pipe = InferencePipeline(HF_TOKEN)
53
+ trainer = Trainer(HF_TOKEN)
54
+
55
+ with gr.Blocks(css='style.css') as demo:
56
+ if os.getenv('IS_SHARED_UI'):
57
+ show_warning(SHARED_UI_WARNING)
58
+ if not torch.cuda.is_available():
59
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
60
+ if not HF_TOKEN:
61
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
62
+
63
+ gr.Markdown(TITLE)
64
+ with gr.Tabs():
65
+ with gr.TabItem('Train'):
66
+ create_training_demo(trainer, pipe)
67
+ with gr.TabItem('Test'):
68
+ create_inference_demo(pipe, HF_TOKEN)
69
+ with gr.TabItem('Upload'):
70
+ gr.Markdown('''
71
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
72
+ ''')
73
+ create_upload_demo(HF_TOKEN)
74
+
75
+ demo.queue(max_size=1).launch(share=True)
app_inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = ['thuanz123/flowerwoman']
14
+
15
+
16
+ class ModelSource(enum.Enum):
17
+ SAMPLE = 'Sample'
18
+ HUB_LIB = 'Hub (realfill-library)'
19
+ LOCAL = 'Local'
20
+
21
+
22
+ class InferenceUtil:
23
+ def __init__(self, hf_token: str | None):
24
+ self.hf_token = hf_token
25
+
26
+ @staticmethod
27
+ def load_sample_model_list():
28
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
29
+
30
+ def load_hub_model_list(self) -> dict:
31
+ api = HfApi(token=self.hf_token)
32
+ choices = [
33
+ info.modelId for info in api.list_models(author='realfill-library')
34
+ ]
35
+ return gr.update(choices=choices,
36
+ value=choices[0] if choices else None)
37
+
38
+ @staticmethod
39
+ def load_local_model_list() -> dict:
40
+ choices = find_exp_dirs()
41
+ return gr.update(choices=choices,
42
+ value=choices[0] if choices else None)
43
+
44
+ def reload_model_list(self, model_source: str) -> dict:
45
+ if model_source == ModelSource.SAMPLE.value:
46
+ return self.load_sample_model_list()
47
+ elif model_source == ModelSource.HUB_LIB.value:
48
+ return self.load_hub_model_list()
49
+ elif model_source == ModelSource.LOCAL.value:
50
+ return self.load_local_model_list()
51
+ else:
52
+ raise ValueError
53
+
54
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
55
+ try:
56
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
57
+ except Exception:
58
+ return '', ''
59
+ target_image = getattr(card.data, 'target_image', '')
60
+ target_mask = getattr(card.data, 'target_mask', '')
61
+ return target_image, target_mask
62
+
63
+ def reload_model_list_and_update_model_info(
64
+ self, model_source: str
65
+ ) -> tuple[dict, str, str]:
66
+ model_list_update = self.reload_model_list(model_source)
67
+ model_list = model_list_update['choices']
68
+ model_info = self.load_model_info(model_list[0] if model_list else '')
69
+ return model_list_update, *model_info
70
+
71
+
72
+ def create_inference_demo(pipe: InferencePipeline,
73
+ hf_token: str | None = None) -> gr.Blocks:
74
+ app = InferenceUtil(hf_token)
75
+
76
+ with gr.Blocks() as demo:
77
+ with gr.Row():
78
+ with gr.Column():
79
+ with gr.Box():
80
+ model_source = gr.Radio(
81
+ label='Model Source',
82
+ choices=[_.value for _ in ModelSource],
83
+ value=ModelSource.SAMPLE.value)
84
+ reload_button = gr.Button('Reload Model List')
85
+ model_id = gr.Dropdown(label='Model ID',
86
+ choices=SAMPLE_MODEL_IDS,
87
+ value=SAMPLE_MODEL_IDS[0])
88
+ with gr.Accordion(
89
+ label=
90
+ 'Model info (Target image and mask used for both training and inference)',
91
+ open=False):
92
+ with gr.Row():
93
+ target_image = gr.Image(
94
+ label='Target Image', interactive=False)
95
+ target_mask = gr.Image(
96
+ label='Target Mask', interactive=False)
97
+ seed = gr.Slider(label='Seed',
98
+ minimum=0,
99
+ maximum=100000,
100
+ step=1,
101
+ value=0)
102
+ with gr.Accordion('Other Parameters', open=False):
103
+ num_steps = gr.Slider(label='Number of Steps',
104
+ minimum=0,
105
+ maximum=100,
106
+ step=1,
107
+ value=25)
108
+ guidance_scale = gr.Slider(label='CFG Scale',
109
+ minimum=0,
110
+ maximum=50,
111
+ step=0.1,
112
+ value=5.0)
113
+
114
+ run_button = gr.Button('Generate')
115
+
116
+ gr.Markdown('''
117
+ - After training, you can press "Reload Model List" button to load your trained model names.
118
+ ''')
119
+ with gr.Column():
120
+ result = gr.Image(label='Result')
121
+
122
+ model_source.change(
123
+ fn=app.reload_model_list_and_update_model_info,
124
+ inputs=model_source,
125
+ outputs=[
126
+ model_id,
127
+ target_image,
128
+ target_mask
129
+ ])
130
+ reload_button.click(
131
+ fn=app.reload_model_list_and_update_model_info,
132
+ inputs=model_source,
133
+ outputs=[
134
+ model_id,
135
+ target_image,
136
+ target_mask
137
+ ])
138
+ model_id.change(fn=app.load_model_info,
139
+ inputs=model_id,
140
+ outputs=[
141
+ target_image,
142
+ target_mask
143
+ ])
144
+ inputs = [
145
+ model_id,
146
+ seed,
147
+ target_image,
148
+ target_mask,
149
+ num_steps,
150
+ guidance_scale,
151
+ ]
152
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
153
+ return demo
154
+
155
+
156
+ if __name__ == '__main__':
157
+ import os
158
+
159
+ hf_token = os.getenv('HF_TOKEN')
160
+ pipe = InferencePipeline(hf_token)
161
+ demo = create_inference_demo(pipe, hf_token)
162
+ demo.queue(max_size=10).launch(share=False)
app_training.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ reference_images = gr.Files(label='Reference images')
22
+ target_image = gr.Files(label='Target image')
23
+ target_mask = gr.Files(label='Target mask')
24
+ gr.Markdown('''
25
+ - Upload reference images of the scene you are planning on training on.
26
+ - For the target image, the inpainting region should be white.
27
+ - For the target mask, white for inpainting and black for keeping as is.
28
+ ''')
29
+ with gr.Box():
30
+ gr.Markdown('Output Model')
31
+ output_model_name = gr.Text(label='Name of your model',
32
+ max_lines=1)
33
+ delete_existing_model = gr.Checkbox(
34
+ label='Delete existing model of the same name',
35
+ value=False)
36
+ with gr.Box():
37
+ gr.Markdown('Upload Settings')
38
+ with gr.Row():
39
+ upload_to_hub = gr.Checkbox(
40
+ label='Upload model to Hub', value=True)
41
+ use_private_repo = gr.Checkbox(label='Private',
42
+ value=True)
43
+ delete_existing_repo = gr.Checkbox(
44
+ label='Delete existing repo of the same name',
45
+ value=False)
46
+ upload_to = gr.Radio(
47
+ label='Upload to',
48
+ choices=[_.value for _ in UploadTarget],
49
+ value=UploadTarget.REALFILL_LIBRARY.value)
50
+ gr.Markdown('''
51
+ - By default, trained models will be uploaded to [ReaFill Library](https://huggingface.co/realfill-library).
52
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
53
+ ''')
54
+
55
+ with gr.Box():
56
+ gr.Markdown('Training Parameters')
57
+ with gr.Row():
58
+ base_model = gr.Text(
59
+ label='Base Model',
60
+ value='stabilityai/stable-diffusion-2-inpainting',
61
+ max_lines=1)
62
+ resolution = gr.Dropdown(choices=['512', '768'],
63
+ value='512',
64
+ label='Resolution')
65
+ num_training_steps = gr.Number(
66
+ label='Number of Training Steps', value=2000, precision=0)
67
+ unet_learning_rate = gr.Number(label='Unet Learning Rate', value=0.0002)
68
+ text_encoder_learning_rate = gr.Number(label='Text Encoder Learning Rate', value=0.00004)
69
+ lora_rank = gr.Number(label='LoRA rank value', value=8, precision=0)
70
+ lora_dropout = gr.Number(label='LoRA dropout rate', value=0.1)
71
+ lora_alpha = gr.Number(label='LoRA alpha value', value=16, precision=0)
72
+ gradient_accumulation = gr.Number(
73
+ label='Number of Gradient Accumulation',
74
+ value=1,
75
+ precision=0)
76
+ seed = gr.Slider(label='Seed',
77
+ minimum=0,
78
+ maximum=100000,
79
+ step=1,
80
+ value=0)
81
+ fp16 = gr.Checkbox(label='FP16', value=True)
82
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
83
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
84
+ value=100,
85
+ precision=0)
86
+ use_wandb = gr.Checkbox(label='Use W&B',
87
+ value=False,
88
+ interactive=bool(
89
+ os.getenv('WANDB_API_KEY')))
90
+ validation_steps = gr.Number(label='Validation Steps',
91
+ value=100,
92
+ precision=0)
93
+ gr.Markdown('''
94
+ - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
95
+ - It takes a few minutes to download the base model first.
96
+ - It will take about 16 minutes to train for 2000 steps with a T4 GPU.
97
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
98
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
99
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
100
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
101
+ ''')
102
+
103
+ remove_gpu_after_training = gr.Checkbox(
104
+ label='Remove GPU after training',
105
+ value=False,
106
+ interactive=bool(os.getenv('SPACE_ID')),
107
+ visible=False)
108
+ run_button = gr.Button('Start Training')
109
+
110
+ with gr.Box():
111
+ gr.Markdown('Output message')
112
+ output_message = gr.Markdown()
113
+
114
+ if pipe is not None:
115
+ run_button.click(fn=pipe.clear)
116
+ run_button.click(fn=trainer.run,
117
+ inputs=[
118
+ reference_images,
119
+ target_image,
120
+ target_mask,
121
+ output_model_name,
122
+ delete_existing_model,
123
+ base_model,
124
+ resolution,
125
+ num_training_steps,
126
+ unet_learning_rate,
127
+ text_encoder_learning_rate,
128
+ lora_rank,
129
+ lora_dropout,
130
+ lora_alpha,
131
+ gradient_accumulation,
132
+ seed,
133
+ fp16,
134
+ use_8bit_adam,
135
+ checkpointing_steps,
136
+ use_wandb,
137
+ validation_steps,
138
+ upload_to_hub,
139
+ use_private_repo,
140
+ delete_existing_repo,
141
+ upload_to,
142
+ remove_gpu_after_training,
143
+ ],
144
+ outputs=output_message)
145
+ return demo
146
+
147
+
148
+ if __name__ == '__main__':
149
+ hf_token = os.getenv('HF_TOKEN')
150
+ trainer = Trainer(hf_token)
151
+ demo = create_training_demo(trainer)
152
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.REALFILL_LIBRARY.value:
33
+ organization = 'realfill-library'
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_model_list() -> dict:
45
+ choices = find_exp_dirs(ignore_repo=True)
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = ModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs(ignore_repo=True)
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.REALFILL_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown('''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [ReaFill Library](https://huggingface.co/realfill-library) (i.e. https://huggingface.co/realfill-library/{model_name}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ REALFILL_LIBRARY = 'RealFill Library'
inference.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+
6
+ import gradio as gr
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from huggingface_hub import ModelCard
11
+
12
+
13
+ class InferencePipeline:
14
+ def __init__(self, hf_token: str | None = None):
15
+ self.hf_token = hf_token
16
+ self.pipe = None
17
+ self.device = torch.device(
18
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
19
+ self.model_id = None
20
+
21
+ def clear(self) -> None:
22
+ self.model_id = None
23
+ del self.pipe
24
+ self.pipe = None
25
+ torch.cuda.empty_cache()
26
+ gc.collect()
27
+
28
+ @staticmethod
29
+ def check_if_model_is_local(model_id: str) -> bool:
30
+ return pathlib.Path(model_id).exists()
31
+
32
+ @staticmethod
33
+ def get_model_card(model_id: str,
34
+ hf_token: str | None = None) -> ModelCard:
35
+ if InferencePipeline.check_if_model_is_local(model_id):
36
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
37
+ else:
38
+ card_path = model_id
39
+ return ModelCard.load(card_path, token=hf_token)
40
+
41
+ def load_pipe(self, model_id: str) -> None:
42
+ if model_id == self.model_id:
43
+ return
44
+
45
+ if self.device.type == 'cpu':
46
+ pipe = DiffusionPipeline.from_pretrained(
47
+ model_id, use_auth_token=self.hf_token)
48
+ else:
49
+ pipe = DiffusionPipeline.from_pretrained(
50
+ model_id, torch_dtype=torch.float16,
51
+ use_auth_token=self.hf_token)
52
+ pipe = pipe.to(self.device)
53
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
54
+ pipe.scheduler.config)
55
+ self.pipe = pipe
56
+
57
+ pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
58
+ self.model_id = model_id # type: ignore
59
+
60
+ def run(
61
+ self,
62
+ model_id: str,
63
+ target_image: str,
64
+ target_mask: str,
65
+ seed: int,
66
+ n_steps: int,
67
+ guidance_scale: float,
68
+ ) -> PIL.Image.Image:
69
+ if not torch.cuda.is_available():
70
+ raise gr.Error('CUDA is not available.')
71
+
72
+ self.load_pipe(model_id)
73
+
74
+ generator = torch.Generator(device=self.device).manual_seed(seed)
75
+ out = self.pipe(
76
+ "a photo of sks",
77
+ image=target_image,
78
+ mask_image=target_mask,
79
+ num_inference_steps=n_steps,
80
+ guidance_scale=guidance_scale,
81
+ generator=generator,
82
+ ) # type: ignore
83
+ return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.20.1
2
+ accelerate==0.23.0
3
+ transformers==4.34.0
4
+ peft==0.5.0
5
+ torch==2.0.1
6
+ torchvision==0.15.2
7
+ ftfy==6.1.1
8
+ tensorboard==2.14.0
9
+ Jinja2==3.1.2
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_realfill.py ADDED
@@ -0,0 +1,952 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import argparse
3
+ import copy
4
+ import itertools
5
+ import logging
6
+ import math
7
+ import os
8
+ import shutil
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+ import transformers
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from huggingface_hub import create_repo, upload_folder
20
+ from packaging import version
21
+ from PIL import Image
22
+ from PIL.ImageOps import exif_transpose
23
+ from torch.utils.data import Dataset
24
+ import torchvision.transforms.v2 as transforms_v2
25
+ from tqdm.auto import tqdm
26
+ from transformers import AutoTokenizer, CLIPTextModel
27
+
28
+ import diffusers
29
+ from diffusers import (
30
+ AutoencoderKL,
31
+ DDPMScheduler,
32
+ StableDiffusionInpaintPipeline,
33
+ DPMSolverMultistepScheduler,
34
+ UNet2DConditionModel,
35
+ )
36
+ from diffusers.optimization import get_scheduler
37
+ from diffusers.utils import check_min_version, is_wandb_available
38
+ from diffusers.utils.import_utils import is_xformers_available
39
+
40
+ from peft import PeftModel, LoraConfig, get_peft_model
41
+
42
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
43
+ check_min_version("0.20.1")
44
+
45
+ logger = get_logger(__name__)
46
+
47
+ def make_mask(images, resolution, times=30):
48
+ mask, times = torch.ones_like(images[0:1, :, :]), np.random.randint(1, times)
49
+ min_size, max_size, margin = np.array([0.03, 0.25, 0.01]) * resolution
50
+ max_size = min(max_size, resolution - margin * 2)
51
+
52
+ for _ in range(times):
53
+ width = np.random.randint(int(min_size), int(max_size))
54
+ height = np.random.randint(int(min_size), int(max_size))
55
+
56
+ x_start = np.random.randint(int(margin), resolution - int(margin) - width + 1)
57
+ y_start = np.random.randint(int(margin), resolution - int(margin) - height + 1)
58
+ mask[:, y_start:y_start + height, x_start:x_start + width] = 0
59
+
60
+ mask = 1 - mask if random.random() < 0.5 else mask
61
+ return mask
62
+
63
+ def save_model_card(
64
+ repo_id: str,
65
+ base_model: str,
66
+ target_image: str,
67
+ target_mask: str,
68
+ repo_folder=None,
69
+ ):
70
+ yaml = f"""
71
+ ---
72
+ license: creativeml-openrail-m
73
+ base_model: {base_model}
74
+ target_image: {target_image}
75
+ target_mask: {target_mask}
76
+ tags:
77
+ - stable-diffusion-inpainting
78
+ - stable-diffusion-inpainting-diffusers
79
+ - text-to-image
80
+ - diffusers
81
+ - realfill
82
+ inference: true
83
+ ---
84
+ """
85
+ model_card = f"""
86
+ # RealFill - {repo_id}
87
+
88
+ This is a realfill model derived from {base_model}. The weights were trained using [RealFill](https://realfill.github.io/).
89
+ """
90
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
91
+ f.write(yaml + model_card)
92
+
93
+ def log_validation(
94
+ text_encoder,
95
+ tokenizer,
96
+ unet,
97
+ args,
98
+ accelerator,
99
+ weight_dtype,
100
+ epoch,
101
+ ):
102
+ logger.info(
103
+ f"Running validation... \nGenerating {args.num_validation_images} images"
104
+ )
105
+
106
+ # create pipeline (note: unet and vae are loaded again in float32)
107
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
108
+ args.pretrained_model_name_or_path,
109
+ tokenizer=tokenizer,
110
+ revision=args.revision,
111
+ torch_dtype=weight_dtype,
112
+ )
113
+
114
+ # set `keep_fp32_wrapper` to True because we do not want to remove
115
+ # mixed precision hooks while we are still training
116
+ pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
117
+ pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
118
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
119
+
120
+ pipeline = pipeline.to(accelerator.device)
121
+ pipeline.set_progress_bar_config(disable=True)
122
+
123
+ # run inference
124
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
125
+
126
+ target_dir = Path(args.train_data_dir) / "target"
127
+ target_image, target_mask = target_dir / "target.png", target_dir / "mask.png"
128
+ image, mask_image = Image.open(target_image), Image.open(target_mask)
129
+
130
+ if image.mode != "RGB":
131
+ image = image.convert("RGB")
132
+
133
+ images = []
134
+ for _ in range(args.num_validation_images):
135
+ image = pipeline(
136
+ prompt="a photo of sks", image=image, mask_image=mask_image,
137
+ num_inference_steps=25, guidance_scale=5, generator=generator
138
+ ).images[0]
139
+ images.append(image)
140
+
141
+ for tracker in accelerator.trackers:
142
+ if tracker.name == "tensorboard":
143
+ np_images = np.stack([np.asarray(img) for img in images])
144
+ tracker.writer.add_images(f"validation", np_images, epoch, dataformats="NHWC")
145
+ if tracker.name == "wandb":
146
+ tracker.log(
147
+ {
148
+ f"validation": [
149
+ wandb.Image(image, caption=str(i)) for i, image in enumerate(images)
150
+ ]
151
+ }
152
+ )
153
+
154
+ del pipeline
155
+ torch.cuda.empty_cache()
156
+
157
+ return images
158
+
159
+ def parse_args(input_args=None):
160
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
161
+ parser.add_argument(
162
+ "--pretrained_model_name_or_path",
163
+ type=str,
164
+ default=None,
165
+ required=True,
166
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
167
+ )
168
+ parser.add_argument(
169
+ "--revision",
170
+ type=str,
171
+ default=None,
172
+ required=False,
173
+ help="Revision of pretrained model identifier from huggingface.co/models.",
174
+ )
175
+ parser.add_argument(
176
+ "--tokenizer_name",
177
+ type=str,
178
+ default=None,
179
+ help="Pretrained tokenizer name or path if not the same as model_name",
180
+ )
181
+ parser.add_argument(
182
+ "--train_data_dir",
183
+ type=str,
184
+ default=None,
185
+ required=True,
186
+ help="A folder containing the training data of images.",
187
+ )
188
+ parser.add_argument(
189
+ "--num_validation_images",
190
+ type=int,
191
+ default=4,
192
+ help="Number of images that should be generated during validation with `validation_conditioning`.",
193
+ )
194
+ parser.add_argument(
195
+ "--validation_steps",
196
+ type=int,
197
+ default=100,
198
+ help=(
199
+ "Run realfill validation every X steps. RealFill validation consists of running the conditioning"
200
+ " `args.validation_conditioning` multiple times: `args.num_validation_images`."
201
+ ),
202
+ )
203
+ parser.add_argument(
204
+ "--output_dir",
205
+ type=str,
206
+ default="realfill-model",
207
+ help="The output directory where the model predictions and checkpoints will be written.",
208
+ )
209
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
210
+ parser.add_argument(
211
+ "--resolution",
212
+ type=int,
213
+ default=512,
214
+ help=(
215
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
216
+ " resolution"
217
+ ),
218
+ )
219
+ parser.add_argument(
220
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
221
+ )
222
+ parser.add_argument("--num_train_epochs", type=int, default=1)
223
+ parser.add_argument(
224
+ "--max_train_steps",
225
+ type=int,
226
+ default=None,
227
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
228
+ )
229
+ parser.add_argument(
230
+ "--checkpointing_steps",
231
+ type=int,
232
+ default=500,
233
+ help=(
234
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
235
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
236
+ " training using `--resume_from_checkpoint`."
237
+ ),
238
+ )
239
+ parser.add_argument(
240
+ "--checkpoints_total_limit",
241
+ type=int,
242
+ default=None,
243
+ help=("Max number of checkpoints to store."),
244
+ )
245
+ parser.add_argument(
246
+ "--resume_from_checkpoint",
247
+ type=str,
248
+ default=None,
249
+ help=(
250
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
251
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--gradient_accumulation_steps",
256
+ type=int,
257
+ default=1,
258
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
259
+ )
260
+ parser.add_argument(
261
+ "--gradient_checkpointing",
262
+ action="store_true",
263
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
264
+ )
265
+ parser.add_argument(
266
+ "--unet_learning_rate",
267
+ type=float,
268
+ default=2e-4,
269
+ help="Learning rate to use for unet.",
270
+ )
271
+ parser.add_argument(
272
+ "--text_encoder_learning_rate",
273
+ type=float,
274
+ default=4e-5,
275
+ help="Learning rate to use for text encoder.",
276
+ )
277
+ parser.add_argument(
278
+ "--scale_lr",
279
+ action="store_true",
280
+ default=False,
281
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
282
+ )
283
+ parser.add_argument(
284
+ "--lr_scheduler",
285
+ type=str,
286
+ default="constant",
287
+ help=(
288
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
289
+ ' "constant", "constant_with_warmup"]'
290
+ ),
291
+ )
292
+ parser.add_argument(
293
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
294
+ )
295
+ parser.add_argument(
296
+ "--lr_num_cycles",
297
+ type=int,
298
+ default=1,
299
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
300
+ )
301
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
302
+ parser.add_argument(
303
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
304
+ )
305
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
306
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
307
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
308
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
309
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
310
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
311
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
312
+ parser.add_argument(
313
+ "--hub_model_id",
314
+ type=str,
315
+ default=None,
316
+ help="The name of the repository to keep in sync with the local `output_dir`.",
317
+ )
318
+ parser.add_argument(
319
+ "--logging_dir",
320
+ type=str,
321
+ default="logs",
322
+ help=(
323
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
324
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
325
+ ),
326
+ )
327
+ parser.add_argument(
328
+ "--allow_tf32",
329
+ action="store_true",
330
+ help=(
331
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
332
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
333
+ ),
334
+ )
335
+ parser.add_argument(
336
+ "--report_to",
337
+ type=str,
338
+ default="tensorboard",
339
+ help=(
340
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
341
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
342
+ ),
343
+ )
344
+ parser.add_argument(
345
+ "--wandb_key",
346
+ type=str,
347
+ default=None,
348
+ help=("If report to option is set to wandb, api-key for wandb used for login to wandb "),
349
+ )
350
+ parser.add_argument(
351
+ "--wandb_project_name",
352
+ type=str,
353
+ default=None,
354
+ help=("If report to option is set to wandb, project name in wandb for log tracking "),
355
+ )
356
+ parser.add_argument(
357
+ "--mixed_precision",
358
+ type=str,
359
+ default=None,
360
+ choices=["no", "fp16", "bf16"],
361
+ help=(
362
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
363
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
364
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
365
+ ),
366
+ )
367
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
368
+ parser.add_argument(
369
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
370
+ )
371
+ parser.add_argument(
372
+ "--set_grads_to_none",
373
+ action="store_true",
374
+ help=(
375
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
376
+ " behaviors, so disable this argument if it causes any problems. More info:"
377
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
378
+ ),
379
+ )
380
+ parser.add_argument(
381
+ "--lora_rank",
382
+ type=int,
383
+ default=16,
384
+ help=("The dimension of the LoRA update matrices."),
385
+ )
386
+ parser.add_argument(
387
+ "--lora_alpha",
388
+ type=int,
389
+ default=27,
390
+ help=("The alpha constant of the LoRA update matrices."),
391
+ )
392
+ parser.add_argument(
393
+ "--lora_dropout",
394
+ type=float,
395
+ default=0.0,
396
+ help="The dropout rate of the LoRA update matrices.",
397
+ )
398
+ parser.add_argument(
399
+ "--lora_bias",
400
+ type=str,
401
+ default="none",
402
+ help="The bias type of the Lora update matrices. Must be 'none', 'all' or 'lora_only'.",
403
+ )
404
+
405
+ if input_args is not None:
406
+ args = parser.parse_args(input_args)
407
+ else:
408
+ args = parser.parse_args()
409
+
410
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
411
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
412
+ args.local_rank = env_local_rank
413
+
414
+ return args
415
+
416
+ class RealFillDataset(Dataset):
417
+ """
418
+ A dataset to prepare the training and conditioning images and
419
+ the masks with the dummy prompt for fine-tuning the model.
420
+ It pre-processes the images, masks and tokenizes the prompts.
421
+ """
422
+
423
+ def __init__(
424
+ self,
425
+ train_data_root,
426
+ tokenizer,
427
+ size=512,
428
+ ):
429
+ self.size = size
430
+ self.tokenizer = tokenizer
431
+
432
+ self.ref_data_root = Path(train_data_root) / "ref"
433
+ self.target_image = Path(train_data_root) / "target" / "target.jpg"
434
+ self.target_mask = Path(train_data_root) / "target" / "mask.jpg"
435
+ if not (self.ref_data_root.exists() and self.target_image.exists() and self.target_mask.exists()):
436
+ raise ValueError("Train images root doesn't exists.")
437
+
438
+ self.train_images_path = list(self.ref_data_root.iterdir()) + [self.target_image]
439
+ self.num_train_images = len(self.train_images_path)
440
+ self.train_prompt = "a photo of sks"
441
+
442
+ self.transform = transforms_v2.Compose(
443
+ [
444
+ transforms_v2.RandomResize(size, int(1.125 * size)),
445
+ transforms_v2.RandomCrop(size),
446
+ transforms_v2.ToImageTensor(),
447
+ transforms_v2.ConvertImageDtype(),
448
+ transforms_v2.Normalize([0.5], [0.5]),
449
+ ]
450
+ )
451
+
452
+ def __len__(self):
453
+ return self.num_train_images
454
+
455
+ def __getitem__(self, index):
456
+ example = {}
457
+
458
+ image = Image.open(self.train_images_path[index])
459
+ image = exif_transpose(image)
460
+
461
+ if not image.mode == "RGB":
462
+ image = image.convert("RGB")
463
+
464
+ if index < len(self) - 1:
465
+ weighting = Image.new("L", image.size)
466
+ else:
467
+ weighting = Image.open(self.target_mask)
468
+ weighting = exif_transpose(weighting)
469
+
470
+ image, weighting = self.transform(image, weighting)
471
+ example["images"], example["weightings"] = image, weighting < 0
472
+
473
+ if random.random() < 0.1:
474
+ example["masks"] = torch.ones_like(example["images"][0:1, :, :])
475
+ else:
476
+ example["masks"] = make_mask(example["images"], self.size)
477
+
478
+ example["conditioning_images"] = example["images"] * (example["masks"] < 0.5)
479
+
480
+ train_prompt = "" if random.random() < 0.1 else self.train_prompt
481
+ example["prompt_ids"] = self.tokenizer(
482
+ train_prompt,
483
+ truncation=True,
484
+ padding="max_length",
485
+ max_length=self.tokenizer.model_max_length,
486
+ return_tensors="pt",
487
+ ).input_ids
488
+
489
+ return example
490
+
491
+ def collate_fn(examples):
492
+ input_ids = [example["prompt_ids"] for example in examples]
493
+ images = [example["images"] for example in examples]
494
+
495
+ masks = [example["masks"] for example in examples]
496
+ weightings = [example["weightings"] for example in examples]
497
+ conditioning_images = [example["conditioning_images"] for example in examples]
498
+
499
+ images = torch.stack(images)
500
+ images = images.to(memory_format=torch.contiguous_format).float()
501
+
502
+ masks = torch.stack(masks)
503
+ masks = masks.to(memory_format=torch.contiguous_format).float()
504
+
505
+ weightings = torch.stack(weightings)
506
+ weightings = weightings.to(memory_format=torch.contiguous_format).float()
507
+
508
+ conditioning_images = torch.stack(conditioning_images)
509
+ conditioning_images = conditioning_images.to(memory_format=torch.contiguous_format).float()
510
+
511
+ input_ids = torch.cat(input_ids, dim=0)
512
+
513
+ batch = {
514
+ "input_ids": input_ids,
515
+ "images": images,
516
+ "masks": masks,
517
+ "weightings": weightings,
518
+ "conditioning_images": conditioning_images,
519
+ }
520
+ return batch
521
+
522
+ def main(args):
523
+ logging_dir = Path(args.output_dir, args.logging_dir)
524
+
525
+ accelerator = Accelerator(
526
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
527
+ mixed_precision=args.mixed_precision,
528
+ log_with=args.report_to,
529
+ project_dir=logging_dir,
530
+ )
531
+
532
+ if args.report_to == "wandb":
533
+ if not is_wandb_available():
534
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
535
+ import wandb
536
+
537
+ wandb.login(key=args.wandb_key)
538
+ wandb.init(project=args.wandb_project_name)
539
+
540
+ # Make one log on every process with the configuration for debugging.
541
+ logging.basicConfig(
542
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
543
+ datefmt="%m/%d/%Y %H:%M:%S",
544
+ level=logging.INFO,
545
+ )
546
+ logger.info(accelerator.state, main_process_only=False)
547
+ if accelerator.is_local_main_process:
548
+ transformers.utils.logging.set_verbosity_warning()
549
+ diffusers.utils.logging.set_verbosity_info()
550
+ else:
551
+ transformers.utils.logging.set_verbosity_error()
552
+ diffusers.utils.logging.set_verbosity_error()
553
+
554
+ # If passed along, set the training seed now.
555
+ if args.seed is not None:
556
+ set_seed(args.seed)
557
+
558
+ # Handle the repository creation
559
+ if accelerator.is_main_process:
560
+ if args.output_dir is not None:
561
+ os.makedirs(args.output_dir, exist_ok=True)
562
+
563
+ if args.push_to_hub:
564
+ repo_id = create_repo(
565
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
566
+ ).repo_id
567
+
568
+ # Load the tokenizer
569
+ if args.tokenizer_name:
570
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
571
+ elif args.pretrained_model_name_or_path:
572
+ tokenizer = AutoTokenizer.from_pretrained(
573
+ args.pretrained_model_name_or_path,
574
+ subfolder="tokenizer",
575
+ revision=args.revision,
576
+ use_fast=False,
577
+ )
578
+
579
+ # Load scheduler and models
580
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
581
+ text_encoder = CLIPTextModel.from_pretrained(
582
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
583
+ )
584
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
585
+ unet = UNet2DConditionModel.from_pretrained(
586
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
587
+ )
588
+
589
+ config = LoraConfig(
590
+ r=args.lora_rank,
591
+ lora_alpha=args.lora_alpha,
592
+ target_modules=["to_k", "to_q", "to_v", "key", "query", "value"],
593
+ lora_dropout=args.lora_dropout,
594
+ bias=args.lora_bias,
595
+ )
596
+ unet = get_peft_model(unet, config)
597
+
598
+ config = LoraConfig(
599
+ r=args.lora_rank,
600
+ lora_alpha=args.lora_alpha,
601
+ target_modules=["k_proj", "q_proj", "v_proj"],
602
+ lora_dropout=args.lora_dropout,
603
+ bias=args.lora_bias,
604
+ )
605
+ text_encoder = get_peft_model(text_encoder, config)
606
+
607
+ vae.requires_grad_(False)
608
+
609
+ if args.enable_xformers_memory_efficient_attention:
610
+ if is_xformers_available():
611
+ import xformers
612
+
613
+ xformers_version = version.parse(xformers.__version__)
614
+ if xformers_version == version.parse("0.0.16"):
615
+ logger.warn(
616
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
617
+ )
618
+ unet.enable_xformers_memory_efficient_attention()
619
+ else:
620
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
621
+
622
+ if args.gradient_checkpointing:
623
+ unet.enable_gradient_checkpointing()
624
+ text_encoder.gradient_checkpointing_enable()
625
+
626
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
627
+ def save_model_hook(models, weights, output_dir):
628
+ if accelerator.is_main_process:
629
+ for model in models:
630
+ sub_dir = "unet" if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) else "text_encoder"
631
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
632
+
633
+ # make sure to pop weight so that corresponding model is not saved again
634
+ weights.pop()
635
+
636
+ def load_model_hook(models, input_dir):
637
+ while len(models) > 0:
638
+ # pop models so that they are not loaded again
639
+ model = models.pop()
640
+
641
+ sub_dir = "unet" if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) else "text_encoder"
642
+ model_cls = UNet2DConditionModel if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) else CLIPTextModel
643
+
644
+ load_model = model_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder=sub_dir)
645
+ load_model = PeftModel.from_pretrained(load_model, input_dir, subfolder=sub_dir)
646
+
647
+ model.load_state_dict(load_model.state_dict())
648
+ del load_model
649
+
650
+ accelerator.register_save_state_pre_hook(save_model_hook)
651
+ accelerator.register_load_state_pre_hook(load_model_hook)
652
+
653
+ # Enable TF32 for faster training on Ampere GPUs,
654
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
655
+ if args.allow_tf32:
656
+ torch.backends.cuda.matmul.allow_tf32 = True
657
+
658
+ if args.scale_lr:
659
+ args.unet_learning_rate = (
660
+ args.unet_learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
661
+ )
662
+
663
+ args.text_encoder_learning_rate = (
664
+ args.text_encoder_learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
665
+ )
666
+
667
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
668
+ if args.use_8bit_adam:
669
+ try:
670
+ import bitsandbytes as bnb
671
+ except ImportError:
672
+ raise ImportError(
673
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
674
+ )
675
+
676
+ optimizer_class = bnb.optim.AdamW8bit
677
+ else:
678
+ optimizer_class = torch.optim.AdamW
679
+
680
+ # Optimizer creation
681
+ optimizer = optimizer_class(
682
+ [
683
+ {"params": unet.parameters(), "lr": args.unet_learning_rate},
684
+ {"params": text_encoder.parameters(), "lr": args.text_encoder_learning_rate}
685
+ ],
686
+ betas=(args.adam_beta1, args.adam_beta2),
687
+ weight_decay=args.adam_weight_decay,
688
+ eps=args.adam_epsilon,
689
+ )
690
+
691
+ # Dataset and DataLoaders creation:
692
+ train_dataset = RealFillDataset(
693
+ train_data_root=args.train_data_dir,
694
+ tokenizer=tokenizer,
695
+ size=args.resolution,
696
+ )
697
+
698
+ train_dataloader = torch.utils.data.DataLoader(
699
+ train_dataset,
700
+ batch_size=args.train_batch_size,
701
+ shuffle=True,
702
+ collate_fn=collate_fn,
703
+ num_workers=1,
704
+ )
705
+
706
+ # Scheduler and math around the number of training steps.
707
+ overrode_max_train_steps = False
708
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
709
+ if args.max_train_steps is None:
710
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
711
+ overrode_max_train_steps = True
712
+
713
+ lr_scheduler = get_scheduler(
714
+ args.lr_scheduler,
715
+ optimizer=optimizer,
716
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
717
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
718
+ num_cycles=args.lr_num_cycles,
719
+ power=args.lr_power,
720
+ )
721
+
722
+ # Prepare everything with our `accelerator`.
723
+ unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
724
+ unet, text_encoder, optimizer, train_dataloader
725
+ )
726
+
727
+ # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
728
+ # as these weights are only used for inference, keeping weights in full precision is not required.
729
+ weight_dtype = torch.float32
730
+ if accelerator.mixed_precision == "fp16":
731
+ weight_dtype = torch.float16
732
+ elif accelerator.mixed_precision == "bf16":
733
+ weight_dtype = torch.bfloat16
734
+
735
+ # Move vae to device and cast to weight_dtype
736
+ vae.to(accelerator.device, dtype=weight_dtype)
737
+
738
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
739
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
740
+ if overrode_max_train_steps:
741
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
742
+ # Afterwards we recalculate our number of training epochs
743
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
744
+
745
+ # We need to initialize the trackers we use, and also store our configuration.
746
+ # The trackers initializes automatically on the main process.
747
+ if accelerator.is_main_process:
748
+ tracker_config = vars(copy.deepcopy(args))
749
+ accelerator.init_trackers("realfill", config=tracker_config)
750
+
751
+ # Train!
752
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
753
+
754
+ logger.info("***** Running training *****")
755
+ logger.info(f" Num examples = {len(train_dataset)}")
756
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
757
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
758
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
759
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
760
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
761
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
762
+ global_step = 0
763
+ first_epoch = 0
764
+
765
+ # Potentially load in the weights and states from a previous save
766
+ if args.resume_from_checkpoint:
767
+ if args.resume_from_checkpoint != "latest":
768
+ path = os.path.basename(args.resume_from_checkpoint)
769
+ else:
770
+ # Get the mos recent checkpoint
771
+ dirs = os.listdir(args.output_dir)
772
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
773
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
774
+ path = dirs[-1] if len(dirs) > 0 else None
775
+
776
+ if path is None:
777
+ accelerator.print(
778
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
779
+ )
780
+ args.resume_from_checkpoint = None
781
+ initial_global_step = 0
782
+ else:
783
+ accelerator.print(f"Resuming from checkpoint {path}")
784
+ accelerator.load_state(os.path.join(args.output_dir, path))
785
+ global_step = int(path.split("-")[1])
786
+
787
+ initial_global_step = global_step
788
+ first_epoch = global_step // num_update_steps_per_epoch
789
+ else:
790
+ initial_global_step = 0
791
+
792
+ progress_bar = tqdm(
793
+ range(0, args.max_train_steps),
794
+ initial=initial_global_step,
795
+ desc="Steps",
796
+ # Only show the progress bar once on each machine.
797
+ disable=not accelerator.is_local_main_process,
798
+ )
799
+
800
+ for epoch in range(first_epoch, args.num_train_epochs):
801
+ unet.train()
802
+ text_encoder.train()
803
+
804
+ for step, batch in enumerate(train_dataloader):
805
+ with accelerator.accumulate(unet, text_encoder):
806
+ # Convert images to latent space
807
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
808
+ latents = latents * 0.18215
809
+
810
+ # Convert masked images to latent space
811
+ conditionings = vae.encode(batch["conditioning_images"].to(dtype=weight_dtype)).latent_dist.sample()
812
+ conditionings = conditionings * 0.18215
813
+
814
+ # Downsample mask and weighting so that they match with the latents
815
+ masks, size = batch["masks"].to(dtype=weight_dtype), latents.shape[2:]
816
+ masks = F.interpolate(masks, size=size)
817
+
818
+ weightings = batch["weightings"].to(dtype=weight_dtype)
819
+ weightings = F.interpolate(weightings, size=size)
820
+
821
+ # Sample noise that we'll add to the latents
822
+ noise = torch.randn_like(latents)
823
+ bsz = latents.shape[0]
824
+
825
+ # Sample a random timestep for each image
826
+ timesteps = torch.randint(
827
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
828
+ )
829
+ timesteps = timesteps.long()
830
+
831
+ # Add noise to the latents according to the noise magnitude at each timestep
832
+ # (this is the forward diffusion process)
833
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
834
+
835
+ # Concatenate noisy latents, masks and conditionings to get inputs to unet
836
+ inputs = torch.cat([noisy_latents, masks, conditionings], dim=1)
837
+
838
+ # Get the text embedding for conditioning
839
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
840
+
841
+ # Predict the noise residual
842
+ model_pred = unet(inputs, timesteps, encoder_hidden_states).sample
843
+
844
+ # Compute the diffusion loss
845
+ assert noise_scheduler.config.prediction_type == "epsilon"
846
+ loss = (weightings * F.mse_loss(model_pred.float(), noise.float(), reduction="none")).mean()
847
+
848
+ # Backpropagate
849
+ accelerator.backward(loss)
850
+ if accelerator.sync_gradients:
851
+ params_to_clip = itertools.chain(
852
+ unet.parameters(), text_encoder.parameters()
853
+ )
854
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
855
+
856
+ optimizer.step()
857
+ lr_scheduler.step()
858
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
859
+
860
+ # Checks if the accelerator has performed an optimization step behind the scenes
861
+ if accelerator.sync_gradients:
862
+ progress_bar.update(1)
863
+ if args.report_to == "wandb":
864
+ accelerator.print(progress_bar)
865
+ global_step += 1
866
+
867
+ if accelerator.is_main_process:
868
+ if global_step % args.checkpointing_steps == 0:
869
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
870
+ if args.checkpoints_total_limit is not None:
871
+ checkpoints = os.listdir(args.output_dir)
872
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
873
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
874
+
875
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
876
+ if len(checkpoints) >= args.checkpoints_total_limit:
877
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
878
+ removing_checkpoints = checkpoints[0:num_to_remove]
879
+
880
+ logger.info(
881
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
882
+ )
883
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
884
+
885
+ for removing_checkpoint in removing_checkpoints:
886
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
887
+ shutil.rmtree(removing_checkpoint)
888
+
889
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
890
+ accelerator.save_state(save_path)
891
+ logger.info(f"Saved state to {save_path}")
892
+
893
+ if global_step % args.validation_steps == 0:
894
+ log_validation(
895
+ text_encoder,
896
+ tokenizer,
897
+ unet,
898
+ args,
899
+ accelerator,
900
+ weight_dtype,
901
+ global_step,
902
+ )
903
+
904
+ logs = {"loss": loss.detach().item()}
905
+ progress_bar.set_postfix(**logs)
906
+ accelerator.log(logs, step=global_step)
907
+
908
+ if global_step >= args.max_train_steps:
909
+ break
910
+
911
+ # Save the lora layers
912
+ accelerator.wait_for_everyone()
913
+ if accelerator.is_main_process:
914
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
915
+ args.pretrained_model_name_or_path,
916
+ unet=accelerator.unwrap_model(unet.merge_and_unload(), keep_fp32_wrapper=True),
917
+ text_encoder=accelerator.unwrap_model(text_encoder.merge_and_unload(), keep_fp32_wrapper=True),
918
+ revision=args.revision,
919
+ )
920
+
921
+ pipeline.save_pretrained(args.output_dir)
922
+
923
+ # Final inference
924
+ images = log_validation(
925
+ text_encoder,
926
+ tokenizer,
927
+ unet,
928
+ args,
929
+ accelerator,
930
+ weight_dtype,
931
+ global_step,
932
+ )
933
+
934
+ if args.push_to_hub:
935
+ save_model_card(
936
+ repo_id,
937
+ images=images,
938
+ base_model=args.pretrained_model_name_or_path,
939
+ repo_folder=args.output_dir,
940
+ )
941
+ upload_folder(
942
+ repo_id=repo_id,
943
+ folder_path=args.output_dir,
944
+ commit_message="End of training",
945
+ ignore_patterns=["step_*", "epoch_*"],
946
+ )
947
+
948
+ accelerator.end_training()
949
+
950
+ if __name__ == "__main__":
951
+ args = parse_args()
952
+ main(args)
trainer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+
10
+ import gradio as gr
11
+ import PIL.Image
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+
16
+ from app_upload import ModelUploader
17
+ from utils import save_model_card
18
+
19
+ URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/realfill-library/share/WctmaLvDHWxnuWoJxagTrzVXbGwxoqoJoG'
20
+
21
+ class Trainer:
22
+ def __init__(self, hf_token: str | None = None):
23
+ self.hf_token = hf_token
24
+ self.api = HfApi(token=hf_token)
25
+ self.model_uploader = ModelUploader(hf_token)
26
+
27
+ def prepare_dataset(self, reference_images: list, resolution: int,
28
+ target_image: PIL.Image, target_mask: PIL.Image,
29
+ train_data_dir: pathlib.Path) -> None:
30
+ shutil.rmtree(train_data_dir, ignore_errors=True)
31
+ train_data_dir.mkdir(parents=True)
32
+
33
+ (train_data_dir / 'ref').mkdir(parents=True)
34
+ (train_data_dir / 'target').mkdir(parents=True)
35
+
36
+ for i, temp_path in enumerate(reference_images):
37
+ image = PIL.Image.open(temp_path.name)
38
+ image = image.convert('RGB')
39
+ out_path = train_data_dir / 'ref' / f'{i:03d}.jpg'
40
+ image.save(out_path, format='JPEG', quality=100)
41
+
42
+ target_image = PIL.Image.open(target_image[0].name)
43
+ target_image = target_image.convert('RGB')
44
+ out_path = train_data_dir / 'target' / f'target.jpg'
45
+ target_image.save(out_path, format='JPEG', quality=100)
46
+
47
+ target_mask = PIL.Image.open(target_mask[0].name)
48
+ target_mask = target_mask.convert('L')
49
+ out_path = train_data_dir / 'target' / f'mask.jpg'
50
+ target_mask.save(out_path, format='JPEG', quality=100)
51
+
52
+ def join_library_org(self) -> None:
53
+ subprocess.run(
54
+ shlex.split(
55
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LIBRARY_ORG}'
56
+ ))
57
+
58
+ def run(
59
+ self,
60
+ reference_images: list | None,
61
+ target_image: PIL.ImageFile | None,
62
+ target_mask: PIL.ImageFile | None,
63
+ output_model_name: str,
64
+ overwrite_existing_model: bool,
65
+ base_model: str,
66
+ resolution_s: str,
67
+ n_steps: int,
68
+ unet_learning_rate: float,
69
+ text_encoder_learning_rate: float,
70
+ lora_rank: int,
71
+ lora_dropout: float,
72
+ lora_alpha: int,
73
+ gradient_accumulation: int,
74
+ seed: int,
75
+ fp16: bool,
76
+ use_8bit_adam: bool,
77
+ checkpointing_steps: int,
78
+ use_wandb: bool,
79
+ validation_steps: int,
80
+ upload_to_hub: bool,
81
+ use_private_repo: bool,
82
+ delete_existing_repo: bool,
83
+ upload_to: str,
84
+ remove_gpu_after_training: bool,
85
+ ) -> str:
86
+ if not torch.cuda.is_available():
87
+ raise gr.Error('CUDA is not available.')
88
+ if reference_images is None:
89
+ raise gr.Error('You need to upload reference images.')
90
+ if target_image is None:
91
+ raise gr.Error('The instance prompt is missing.')
92
+
93
+ resolution = int(resolution_s)
94
+
95
+ if not output_model_name:
96
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
97
+ output_model_name = f'realfill-{timestamp}'
98
+ output_model_name = slugify.slugify(output_model_name)
99
+
100
+ repo_dir = pathlib.Path(__file__).parent
101
+ output_dir = repo_dir / 'experiments' / output_model_name
102
+ if overwrite_existing_model or upload_to_hub:
103
+ shutil.rmtree(output_dir, ignore_errors=True)
104
+ output_dir.mkdir(parents=True)
105
+
106
+ train_data_dir = repo_dir / 'training_data' / output_model_name
107
+ self.prepare_dataset(reference_images, resolution, target_image, target_mask, train_data_dir)
108
+
109
+ if upload_to_hub:
110
+ self.join_library_org()
111
+
112
+ command = f'''
113
+ python train_realfill.py \
114
+ --pretrained_model_name_or_path={base_model} \
115
+ --train_data_dir={train_data_dir} \
116
+ --output_dir={output_dir} \
117
+ --resolution={resolution} \
118
+ --train_batch_size=16 \
119
+ --gradient_accumulation_steps={gradient_accumulation} --gradient_checkpointing \
120
+ --unet_learning_rate={unet_learning_rate} \
121
+ --text_encoder_learning_rate={text_encoder_learning_rate} \
122
+ --lr_scheduler=constant \
123
+ --lr_warmup_steps=100 \
124
+ --set_grads_to_none \
125
+ --max_train_steps={n_steps} \
126
+ --checkpointing_steps={checkpointing_steps} \
127
+ --validation_steps={validation_steps} \
128
+ --lora_rank={lora_rank} \
129
+ --lora_dropout={lora_dropout} \
130
+ --lora_alpha={lora_alpha} \
131
+ --seed={seed}
132
+ '''
133
+ if fp16:
134
+ command += ' --mixed_precision fp16'
135
+ if use_8bit_adam:
136
+ command += ' --use_8bit_adam'
137
+ if use_wandb:
138
+ command += ' --report_to wandb'
139
+
140
+ with open(output_dir / 'train.sh', 'w') as f:
141
+ command_s = ' '.join(command.split())
142
+ f.write(command_s)
143
+ subprocess.run(shlex.split(command))
144
+ save_model_card(save_dir=output_dir,
145
+ base_model=base_model,
146
+ target_image=train_data_dir / 'target' / 'target.jpg',
147
+ target_mask=train_data_dir / 'target' / 'mask.jpg')
148
+
149
+ message = 'Training completed!'
150
+ print(message)
151
+
152
+ if upload_to_hub:
153
+ upload_message = self.model_uploader.upload_model(
154
+ folder_path=output_dir.as_posix(),
155
+ repo_name=output_model_name,
156
+ upload_to=upload_to,
157
+ private=use_private_repo,
158
+ delete_existing_repo=delete_existing_repo)
159
+ print(upload_message)
160
+ message = message + '\n' + upload_message
161
+
162
+ if remove_gpu_after_training:
163
+ space_id = os.getenv('SPACE_ID')
164
+ if space_id:
165
+ self.api.request_space_hardware(repo_id=space_id,
166
+ hardware='cpu-basic')
167
+
168
+ return message
uploader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.api = HfApi(token=hf_token)
9
+
10
+ def get_username(self) -> str:
11
+ return self.api.whoami()['name']
12
+
13
+ def upload(self,
14
+ folder_path: str,
15
+ repo_name: str,
16
+ organization: str = '',
17
+ repo_type: str = 'model',
18
+ private: bool = True,
19
+ delete_existing_repo: bool = False) -> str:
20
+ if not folder_path:
21
+ raise ValueError
22
+ if not repo_name:
23
+ raise ValueError
24
+ if not organization:
25
+ organization = self.get_username()
26
+ repo_id = f'{organization}/{repo_name}'
27
+ if delete_existing_repo:
28
+ try:
29
+ self.api.delete_repo(repo_id, repo_type=repo_type)
30
+ except Exception:
31
+ pass
32
+ try:
33
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
34
+ self.api.upload_folder(repo_id=repo_id,
35
+ folder_path=folder_path,
36
+ path_in_repo='.',
37
+ repo_type=repo_type)
38
+ url = f'https://huggingface.co/{repo_id}'
39
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
40
+ except Exception as e:
41
+ message = str(e)
42
+ return message
utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'model_index.json').exists()
15
+ ]
16
+ if ignore_repo:
17
+ exp_dirs = [
18
+ exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
19
+ ]
20
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
21
+
22
+
23
+ def save_model_card(
24
+ save_dir: pathlib.Path,
25
+ base_model: str,
26
+ target_image: str,
27
+ target_mask: str,
28
+ ) -> None:
29
+ model_card = f'''---
30
+ license: creativeml-openrail-m
31
+ base_model: {base_model}
32
+ target_image: {target_image}
33
+ target_mask: {target_mask}
34
+ tags:
35
+ - stable-diffusion-inpainting
36
+ - stable-diffusion-inpainting-diffusers
37
+ - text-to-image
38
+ - diffusers
39
+ - realfill
40
+ inference: true
41
+ ---
42
+ # RealFill - {save_dir.name}
43
+
44
+ These are RealFill weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained using [RealFill](https://realfill.github.io/).
45
+ '''
46
+
47
+ with open(save_dir / 'README.md', 'w') as f:
48
+ f.write(model_card)