diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..77b1579bda12971e4a9e737cb7d618b01c0289a8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,142 @@ +*.pth +*.pt +*.pyc +src/ +outputs/ +models/ +models +.DS_Store +ia_config.ini +.eslintrc +.eslintrc.json +pyproject.toml + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..29f81d812f3e768fa89638d1f72920dbfd1413a8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 4d3cb097be3fd3a84108fe70bb08f3f3b3141536..978542474e83a7dbaf589b6766ca17fac8071479 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,149 @@ ---- -title: ' ' -emoji: 🚀 -colorFrom: pink -colorTo: blue -sdk: gradio -sdk_version: 4.40.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: _ +app_file: iasam_app.py +sdk: gradio +sdk_version: 3.50.2 +--- +# Inpaint Anything (Inpainting with Segment Anything) + +Inpaint Anything performs stable diffusion inpainting on a browser UI using any mask selected from the output of [Segment Anything](https://github.com/facebookresearch/segment-anything). + + +Using Segment Anything enables users to specify masks by simply pointing to the desired areas, instead of manually filling them in. This can increase the efficiency and accuracy of the mask creation process, leading to potentially higher-quality inpainting results while saving time and effort. + +[Extension version for AUTOMATIC1111's Web UI](https://github.com/Uminosachi/sd-webui-inpaint-anything) + +![Explanation image](images/inpaint_anything_explanation_image_1.png) + +## Installation + +Please follow these steps to install the software: + +* Create a new conda environment: + +```bash +conda create -n inpaint python=3.10 +conda activate inpaint +``` + +* Clone the software repository: + +```bash +git clone https://github.com/Uminosachi/inpaint-anything.git +cd inpaint-anything +``` + +* For the CUDA environment, install the following packages: + +```bash +pip install -r requirements.txt +``` + +* If you are using macOS, please install the package from the following file instead: + +```bash +pip install -r requirements_mac.txt +``` + +## Running the application + +```bash +python iasam_app.py +``` + +* Open http://127.0.0.1:7860/ in your browser. +* Note: If you have a privacy protection extension enabled in your web browser, such as DuckDuckGo, you may not be able to retrieve the mask from your sketch. + +### Options + +* `--save-seg`: Save the segmentation image generated by SAM. +* `--offline`: Execute inpainting using an offline network. +* `--sam-cpu`: Perform the Segment Anything operation on CPU. + +## Downloading the Model + +* Launch this application. +* Click on the `Download model` button, located next to the [Segment Anything Model ID](https://github.com/facebookresearch/segment-anything#model-checkpoints). This includes the [SAM 2](https://github.com/facebookresearch/segment-anything-2), [Segment Anything in High Quality Model ID](https://github.com/SysCV/sam-hq), [Fast Segment Anything](https://github.com/CASIA-IVA-Lab/FastSAM), and [Faster Segment Anything (MobileSAM)](https://github.com/ChaoningZhang/MobileSAM). + * Please note that the SAM is available in three sizes: Base, Large, and Huge. Remember, larger sizes consume more VRAM. +* Wait for the download to complete. +* The downloaded model file will be stored in the `models` directory of this application's repository. + +## Usage + +* Drag and drop your image onto the input image area. + * Outpainting can be achieved by the `Padding options`, configuring the scale and balance, and then clicking on the `Run Padding` button. + * The `Anime Style` checkbox enhances segmentation mask detection, particularly in anime style images, at the expense of a slight reduction in mask quality. +* Click on the `Run Segment Anything` button. +* Use sketching to point the area you want to inpaint. You can undo and adjust the pen size. + * Hover over either the SAM image or the mask image and press the `S` key for Fullscreen mode, or the `R` key to Reset zoom. +* Click on the `Create mask` button. The mask will appear in the selected mask image area. + +### Mask Adjustment + +* `Expand mask region` button: Use this to slightly expand the area of the mask for broader coverage. +* `Trim mask by sketch` button: Clicking this will exclude the sketched area from the mask. +* `Add mask by sketch` button: Clicking this will add the sketched area to the mask. + +### Inpainting Tab + +* Enter your desired Prompt and Negative Prompt, then choose the Inpainting Model ID. +* Click on the `Run Inpainting` button (**Please note that it may take some time to download the model for the first time**). + * In the Advanced options, you can adjust the Sampler, Sampling Steps, Guidance Scale, and Seed. + * If you enable the `Mask area Only` option, modifications will be confined to the designated mask area only. +* Adjust the iteration slider to perform inpainting multiple times with different seeds. +* The inpainting process is powered by [diffusers](https://github.com/huggingface/diffusers). + +#### Tips + +* You can directly drag and drop the inpainted image into the input image field on the Web UI. (useful with Chrome and Edge browsers) + +#### Model Cache +* The inpainting model, which is saved in HuggingFace's cache and includes `inpaint` (case-insensitive) in its repo_id, will also be added to the Inpainting Model ID dropdown list. + * If there's a specific model you'd like to use, you can cache it in advance using the following Python commands: + ```bash + python + ``` + ```python + from diffusers import StableDiffusionInpaintPipeline + pipe = StableDiffusionInpaintPipeline.from_pretrained("Uminosachi/dreamshaper_5-inpainting") + exit() + ``` +* The model diffusers downloaded is typically stored in your home directory. You can find it at `/home/username/.cache/huggingface/hub` for Linux and MacOS users, or at `C:\Users\username\.cache\huggingface\hub` for Windows users. + * When executing inpainting, if the following error is output to the console, try deleting the corresponding model from the cache folder mentioned above: + ``` + An error occurred while trying to fetch model name... + ``` + +### Cleaner Tab + +* Choose the Cleaner Model ID. +* Click on the `Run Cleaner` button (**Please note that it may take some time to download the model for the first time**). +* Cleaner process is performed using [Lama Cleaner](https://github.com/Sanster/lama-cleaner). + +### Mask only Tab + +* Gives ability to just save mask without any other processing, so it's then possible to use the mask in other graphic applications. +* `Get mask as alpha of image` button: Save the mask as RGBA image, with the mask put into the alpha channel of the input image. +* `Get mask` button: Save the mask as RGB image. + +![UI image](images/inpaint_anything_ui_image_1.png) + +## Auto-saving images + +* The inpainted image will be automatically saved in the folder that matches the current date within the `outputs` directory. + +## Development + +With the [Inpaint Anything library](README_DEV.md), you can perform segmentation and create masks using sketches from other applications. + +## License + +The source code is licensed under the [Apache 2.0 license](LICENSE). + +## References + +* Ravi, N., Gabeur, V., Hu, Y.-T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädel, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K. V., Carion, N., Wu, C.-Y., Girshick, R., Dollár, P., & Feichtenhofer, C. (2024). [SAM 2: Segment Anything in Images and Videos](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/). arXiv preprint. +* Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T., Whitehead, S., Berg, A. C., Lo, W-Y., Dollár, P., & Girshick, R. (2023). [Segment Anything](https://arxiv.org/abs/2304.02643). arXiv:2304.02643. +* Ke, L., Ye, M., Danelljan, M., Liu, Y., Tai, Y-W., Tang, C-K., & Yu, F. (2023). [Segment Anything in High Quality](https://arxiv.org/abs/2306.01567). arXiv:2306.01567. +* Zhao, X., Ding, W., An, Y., Du, Y., Yu, T., Li, M., Tang, M., & Wang, J. (2023). [Fast Segment Anything](https://arxiv.org/abs/2306.12156). arXiv:2306.12156 [cs.CV]. +* Zhang, C., Han, D., Qiao, Y., Kim, J. U., Bae, S-H., Lee, S., & Hong, C. S. (2023). [Faster Segment Anything: Towards Lightweight SAM for Mobile Applications](https://arxiv.org/abs/2306.14289). arXiv:2306.14289. diff --git a/README_DEV.md b/README_DEV.md new file mode 100644 index 0000000000000000000000000000000000000000..4fad2a24043b65d0a1c63192dbc00840e1ce8fd8 --- /dev/null +++ b/README_DEV.md @@ -0,0 +1,61 @@ +# Usage of Inpaint Anything Library + +## Introduction + +The `inpalib` from the `inpaint-anything` package lets you segment images and create masks using sketches from other applications. + +## Code Breakdown + +### Imports and Module Initialization + +```python +import importlib + +import numpy as np +from PIL import Image, ImageDraw + +inpalib = importlib.import_module("inpaint-anything.inpalib") +``` + +### Fetch Model IDs + +```python +available_sam_ids = inpalib.get_available_sam_ids() + +use_sam_id = "sam_hq_vit_l.pth" +# assert use_sam_id in available_sam_ids, f"Invalid SAM ID: {use_sam_id}" +``` + +Note: Only the models downloaded via the Inpaint Anything are available. + +### Generate Segments Image + +```python +input_image = np.array(Image.open("/path/to/image.png")) + +sam_masks = inpalib.generate_sam_masks(input_image, use_sam_id, anime_style_chk=False) +sam_masks = inpalib.sort_masks_by_area(sam_masks) + +seg_color_image = inpalib.create_seg_color_image(input_image, sam_masks) + +Image.fromarray(seg_color_image).save("/path/to/seg_color_image.png") +``` + +drawing drawing + +### Create Mask from Sketch + +```python +sketch_image = Image.fromarray(np.zeros_like(input_image)) + +draw = ImageDraw.Draw(sketch_image) +draw.point((input_image.shape[1] // 2, input_image.shape[0] // 2), fill=(255, 255, 255)) + +mask_image = inpalib.create_mask_image(np.array(sketch_image), sam_masks, ignore_black_chk=True) + +Image.fromarray(mask_image).save("/path/to/mask_image.png") +``` + +drawing + +Note: Ensure you adjust the file paths before executing the code. diff --git a/fast_sam/__init__.py b/fast_sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d58ef79249f6bc66708e8b8cf4b7560c9b8149e --- /dev/null +++ b/fast_sam/__init__.py @@ -0,0 +1,9 @@ +from .fast_sam_wrapper import FastSAM +from .fast_sam_wrapper import FastSamAutomaticMaskGenerator + +fast_sam_model_registry = { + "FastSAM-x": FastSAM, + "FastSAM-s": FastSAM, +} + +__all__ = ["FastSAM", "FastSamAutomaticMaskGenerator", "fast_sam_model_registry"] diff --git a/fast_sam/fast_sam_wrapper.py b/fast_sam/fast_sam_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..71454076010baffcdad522a33e91ecc6292c276e --- /dev/null +++ b/fast_sam/fast_sam_wrapper.py @@ -0,0 +1,90 @@ +import inspect +import math +from typing import Any, Dict, List + +import cv2 +import numpy as np +import torch +import ultralytics + +if hasattr(ultralytics, "FastSAM"): + from ultralytics import FastSAM as YOLO +else: + from ultralytics import YOLO + + +class FastSAM: + def __init__( + self, + checkpoint: str, + ) -> None: + self.model_path = checkpoint + self.model = YOLO(self.model_path) + + if not hasattr(torch.nn.Upsample, "recompute_scale_factor"): + torch.nn.Upsample.recompute_scale_factor = None + + def to(self, device) -> None: + self.model.to(device) + + @property + def device(self) -> Any: + return self.model.device + + def __call__(self, source=None, stream=False, **kwargs) -> Any: + return self.model(source=source, stream=stream, **kwargs) + + +class FastSamAutomaticMaskGenerator: + def __init__( + self, + model: FastSAM, + points_per_batch: int = None, + pred_iou_thresh: float = None, + stability_score_thresh: float = None, + ) -> None: + self.model = model + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.conf = 0.25 if stability_score_thresh >= 0.95 else 0.15 + + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + height, width = image.shape[:2] + new_height = math.ceil(height / 32) * 32 + new_width = math.ceil(width / 32) * 32 + resize_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC) + + backup_nn_dict = {} + for key, _ in torch.nn.__dict__.copy().items(): + if not inspect.isclass(torch.nn.__dict__.get(key)) and "Norm" in key: + backup_nn_dict[key] = torch.nn.__dict__.pop(key) + + results = self.model( + source=resize_image, + stream=False, + imgsz=max(new_height, new_width), + device=self.model.device, + retina_masks=True, + iou=0.7, + conf=self.conf, + max_det=256) + + for key, value in backup_nn_dict.items(): + setattr(torch.nn, key, value) + # assert backup_nn_dict[key] == torch.nn.__dict__[key] + + annotations = results[0].masks.data + + if isinstance(annotations[0], torch.Tensor): + annotations = np.array(annotations.cpu()) + + annotations_list = [] + for mask in annotations: + mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) + mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((7, 7), np.uint8)) + mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_AREA) + + annotations_list.append(dict(segmentation=mask.astype(bool))) + + return annotations_list diff --git a/ia_check_versions.py b/ia_check_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..c51c461214e91e4427a0fed915c2d834c63b998f --- /dev/null +++ b/ia_check_versions.py @@ -0,0 +1,74 @@ +from functools import cached_property +from importlib.metadata import version +from importlib.util import find_spec + +import torch +from packaging.version import parse + + +def get_module_version(module_name): + try: + module_version = version(module_name) + except Exception: + module_version = None + return module_version + + +def compare_version(version1, version2): + if not isinstance(version1, str) or not isinstance(version2, str): + return None + + if parse(version1) > parse(version2): + return 1 + elif parse(version1) < parse(version2): + return -1 + else: + return 0 + + +def compare_module_version(module_name, version_string): + module_version = get_module_version(module_name) + + result = compare_version(module_version, version_string) + return result if result is not None else -2 + + +class IACheckVersions: + @cached_property + def diffusers_enable_cpu_offload(self): + if (find_spec("diffusers") is not None and compare_module_version("diffusers", "0.15.0") >= 0 and + find_spec("accelerate") is not None and compare_module_version("accelerate", "0.17.0") >= 0 and + torch.cuda.is_available()): + return True + else: + return False + + @cached_property + def torch_mps_is_available(self): + if compare_module_version("torch", "2.0.1") < 0: + if not getattr(torch, "has_mps", False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + @cached_property + def torch_on_amd_rocm(self): + if find_spec("torch") is not None and "rocm" in version("torch"): + return True + else: + return False + + @cached_property + def gradio_version_is_old(self): + if find_spec("gradio") is not None and compare_module_version("gradio", "3.34.0") <= 0: + return True + else: + return False + + +ia_check_versions = IACheckVersions() diff --git a/ia_config.py b/ia_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c74147436cd55b6632f213bf21cd138799833e --- /dev/null +++ b/ia_config.py @@ -0,0 +1,115 @@ +import configparser +# import json +import os +from types import SimpleNamespace + +from ia_ui_items import get_inp_model_ids, get_sam_model_ids + + +class IAConfig: + SECTIONS = SimpleNamespace( + DEFAULT=configparser.DEFAULTSECT, + USER="USER", + ) + + KEYS = SimpleNamespace( + SAM_MODEL_ID="sam_model_id", + INP_MODEL_ID="inp_model_id", + ) + + PATHS = SimpleNamespace( + INI=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ia_config.ini"), + ) + + global_args = {} + + def __init__(self): + self.ids_dict = {} + self.ids_dict[IAConfig.KEYS.SAM_MODEL_ID] = { + "list": get_sam_model_ids(), + "index": 1, + } + self.ids_dict[IAConfig.KEYS.INP_MODEL_ID] = { + "list": get_inp_model_ids(), + "index": 0, + } + + +ia_config = IAConfig() + + +def setup_ia_config_ini(): + ia_config_ini = configparser.ConfigParser(defaults={}) + if os.path.isfile(IAConfig.PATHS.INI): + ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8") + + changed = False + for key, ids_info in ia_config.ids_dict.items(): + if not ia_config_ini.has_option(IAConfig.SECTIONS.DEFAULT, key): + if len(ids_info["list"]) > ids_info["index"]: + ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]] + changed = True + else: + if len(ids_info["list"]) > ids_info["index"] and ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] != ids_info["list"][ids_info["index"]]: + ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]] + changed = True + + if changed: + with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f: + ia_config_ini.write(f) + + +def get_ia_config(key, section=IAConfig.SECTIONS.DEFAULT): + setup_ia_config_ini() + + ia_config_ini = configparser.ConfigParser(defaults={}) + ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8") + + if ia_config_ini.has_option(section, key): + return ia_config_ini[section][key] + + section = IAConfig.SECTIONS.DEFAULT + if ia_config_ini.has_option(section, key): + return ia_config_ini[section][key] + + return None + + +def get_ia_config_index(key, section=IAConfig.SECTIONS.DEFAULT): + value = get_ia_config(key, section) + + ids_dict = ia_config.ids_dict + if value is None: + if key in ids_dict.keys(): + ids_info = ids_dict[key] + return ids_info["index"] + else: + return 0 + else: + if key in ids_dict.keys(): + ids_info = ids_dict[key] + return ids_info["list"].index(value) if value in ids_info["list"] else ids_info["index"] + else: + return 0 + + +def set_ia_config(key, value, section=IAConfig.SECTIONS.DEFAULT): + setup_ia_config_ini() + + ia_config_ini = configparser.ConfigParser(defaults={}) + ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8") + + if ia_config_ini.has_option(section, key) and ia_config_ini[section][key] == value: + return + + if section != IAConfig.SECTIONS.DEFAULT and not ia_config_ini.has_section(section): + ia_config_ini[section] = {} + + try: + ia_config_ini[section][key] = value + except Exception: + ia_config_ini[section] = {} + ia_config_ini[section][key] = value + + with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f: + ia_config_ini.write(f) diff --git a/ia_devices.py b/ia_devices.py new file mode 100644 index 0000000000000000000000000000000000000000..bbacbd186d55a22936e073d68408220e22ff1233 --- /dev/null +++ b/ia_devices.py @@ -0,0 +1,10 @@ +import torch + + +class TorchDevices: + def __init__(self): + self.cpu = torch.device("cpu") + self.device = torch.device("cuda") if torch.cuda.is_available() else self.cpu + + +devices = TorchDevices() diff --git a/ia_file_manager.py b/ia_file_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..59e0fbac6df0d8fce0762e3b85b76847e8ec39fe --- /dev/null +++ b/ia_file_manager.py @@ -0,0 +1,71 @@ +import os +from datetime import datetime +from huggingface_hub import snapshot_download +from ia_logging import ia_logging + + +class IAFileManager: + DOWNLOAD_COMPLETE = "Download complete" + + def __init__(self) -> None: + self._ia_outputs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), + "outputs", + datetime.now().strftime("%Y-%m-%d")) + + self._ia_models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + + @property + def outputs_dir(self) -> str: + """Get inpaint-anything outputs directory. + + Returns: + str: inpaint-anything outputs directory + """ + if not os.path.isdir(self._ia_outputs_dir): + os.makedirs(self._ia_outputs_dir, exist_ok=True) + return self._ia_outputs_dir + + @property + def models_dir(self) -> str: + """Get inpaint-anything models directory. + + Returns: + str: inpaint-anything models directory + """ + if not os.path.isdir(self._ia_models_dir): + os.makedirs(self._ia_models_dir, exist_ok=True) + return self._ia_models_dir + + @property + def savename_prefix(self) -> str: + """Get inpaint-anything savename prefix. + + Returns: + str: inpaint-anything savename prefix + """ + return datetime.now().strftime("%Y%m%d-%H%M%S") + + +ia_file_manager = IAFileManager() + + +def download_model_from_hf(hf_model_id, local_files_only=False): + """Download model from HuggingFace Hub. + + Args: + sam_model_id (str): HuggingFace model id + local_files_only (bool, optional): If True, use only local files. Defaults to False. + + Returns: + str: download status + """ + if not local_files_only: + ia_logging.info(f"Downloading {hf_model_id}") + try: + snapshot_download(repo_id=hf_model_id, local_files_only=local_files_only) + except FileNotFoundError: + return f"{hf_model_id} not found, please download" + except Exception as e: + return str(e) + + return IAFileManager.DOWNLOAD_COMPLETE diff --git a/ia_get_dataset_colormap.py b/ia_get_dataset_colormap.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e35128e012ad431ecaba2cc3bf2116881b5dee --- /dev/null +++ b/ia_get_dataset_colormap.py @@ -0,0 +1,416 @@ +# Lint as: python2, python3 +# Copyright 2018 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Visualizes the segmentation results via specified color map. + +Visualizes the semantic segmentation results by the color map +defined by the different datasets. Supported colormaps are: + +* ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/). + +* Cityscapes dataset (https://www.cityscapes-dataset.com). + +* Mapillary Vistas (https://research.mapillary.com). + +* PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/). +""" + +from __future__ import absolute_import, division, print_function + +import numpy as np + +# from six.moves import range + +# Dataset names. +_ADE20K = 'ade20k' +_CITYSCAPES = 'cityscapes' +_MAPILLARY_VISTAS = 'mapillary_vistas' +_PASCAL = 'pascal' + +# Max number of entries in the colormap for each dataset. +_DATASET_MAX_ENTRIES = { + _ADE20K: 151, + _CITYSCAPES: 256, + _MAPILLARY_VISTAS: 66, + _PASCAL: 512, +} + + +def create_ade20k_label_colormap(): + """Creates a label colormap used in ADE20K segmentation benchmark. + + Returns: + A colormap for visualizing segmentation results. + """ + return np.asarray([ + [0, 0, 0], + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + [255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ]) + + +def create_cityscapes_label_colormap(): + """Creates a label colormap used in CITYSCAPES segmentation benchmark. + + Returns: + A colormap for visualizing segmentation results. + """ + colormap = np.zeros((256, 3), dtype=np.uint8) + colormap[0] = [128, 64, 128] + colormap[1] = [244, 35, 232] + colormap[2] = [70, 70, 70] + colormap[3] = [102, 102, 156] + colormap[4] = [190, 153, 153] + colormap[5] = [153, 153, 153] + colormap[6] = [250, 170, 30] + colormap[7] = [220, 220, 0] + colormap[8] = [107, 142, 35] + colormap[9] = [152, 251, 152] + colormap[10] = [70, 130, 180] + colormap[11] = [220, 20, 60] + colormap[12] = [255, 0, 0] + colormap[13] = [0, 0, 142] + colormap[14] = [0, 0, 70] + colormap[15] = [0, 60, 100] + colormap[16] = [0, 80, 100] + colormap[17] = [0, 0, 230] + colormap[18] = [119, 11, 32] + return colormap + + +def create_mapillary_vistas_label_colormap(): + """Creates a label colormap used in Mapillary Vistas segmentation benchmark. + + Returns: + A colormap for visualizing segmentation results. + """ + return np.asarray([ + [165, 42, 42], + [0, 192, 0], + [196, 196, 196], + [190, 153, 153], + [180, 165, 180], + [102, 102, 156], + [102, 102, 156], + [128, 64, 255], + [140, 140, 200], + [170, 170, 170], + [250, 170, 160], + [96, 96, 96], + [230, 150, 140], + [128, 64, 128], + [110, 110, 110], + [244, 35, 232], + [150, 100, 100], + [70, 70, 70], + [150, 120, 90], + [220, 20, 60], + [255, 0, 0], + [255, 0, 0], + [255, 0, 0], + [200, 128, 128], + [255, 255, 255], + [64, 170, 64], + [128, 64, 64], + [70, 130, 180], + [255, 255, 255], + [152, 251, 152], + [107, 142, 35], + [0, 170, 30], + [255, 255, 128], + [250, 0, 30], + [0, 0, 0], + [220, 220, 220], + [170, 170, 170], + [222, 40, 40], + [100, 170, 30], + [40, 40, 40], + [33, 33, 33], + [170, 170, 170], + [0, 0, 142], + [170, 170, 170], + [210, 170, 100], + [153, 153, 153], + [128, 128, 128], + [0, 0, 142], + [250, 170, 30], + [192, 192, 192], + [220, 220, 0], + [180, 165, 180], + [119, 11, 32], + [0, 0, 142], + [0, 60, 100], + [0, 0, 142], + [0, 0, 90], + [0, 0, 230], + [0, 80, 100], + [128, 64, 64], + [0, 0, 110], + [0, 0, 70], + [0, 0, 192], + [32, 32, 32], + [0, 0, 0], + [0, 0, 0], + ]) + + +def create_pascal_label_colormap(): + """Creates a label colormap used in PASCAL VOC segmentation benchmark. + + Returns: + A colormap for visualizing segmentation results. + """ + colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int) + ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int) + + for shift in reversed(list(range(8))): + for channel in range(3): + colormap[:, channel] |= bit_get(ind, channel) << shift + ind >>= 3 + + return colormap + + +def get_ade20k_name(): + return _ADE20K + + +def get_cityscapes_name(): + return _CITYSCAPES + + +def get_mapillary_vistas_name(): + return _MAPILLARY_VISTAS + + +def get_pascal_name(): + return _PASCAL + + +def bit_get(val, idx): + """Gets the bit value. + + Args: + val: Input value, int or numpy int array. + idx: Which bit of the input val. + + Returns: + The "idx"-th bit of input val. + """ + return (val >> idx) & 1 + + +def create_label_colormap(dataset=_PASCAL): + """Creates a label colormap for the specified dataset. + + Args: + dataset: The colormap used in the dataset. + + Returns: + A numpy array of the dataset colormap. + + Raises: + ValueError: If the dataset is not supported. + """ + if dataset == _ADE20K: + return create_ade20k_label_colormap() + elif dataset == _CITYSCAPES: + return create_cityscapes_label_colormap() + elif dataset == _MAPILLARY_VISTAS: + return create_mapillary_vistas_label_colormap() + elif dataset == _PASCAL: + return create_pascal_label_colormap() + else: + raise ValueError('Unsupported dataset.') + + +def label_to_color_image(label, dataset=_PASCAL): + """Adds color defined by the dataset colormap to the label. + + Args: + label: A 2D array with integer type, storing the segmentation label. + dataset: The colormap used in the dataset. + + Returns: + result: A 2D array with floating type. The element of the array + is the color indexed by the corresponding element in the input label + to the dataset color map. + + Raises: + ValueError: If label is not of rank 2 or its value is larger than color + map maximum entry. + """ + if label.ndim != 2: + raise ValueError('Expect 2-D input label. Got {}'.format(label.shape)) + + if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]: + raise ValueError( + 'label value too large: {} >= {}.'.format( + np.max(label), _DATASET_MAX_ENTRIES[dataset])) + + colormap = create_label_colormap(dataset) + return colormap[label] + + +def get_dataset_colormap_max_entries(dataset): + return _DATASET_MAX_ENTRIES[dataset] diff --git a/ia_logging.py b/ia_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..c0cce539befb121f9397a3240cdd73f30089f84f --- /dev/null +++ b/ia_logging.py @@ -0,0 +1,14 @@ +import logging +import warnings + +warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers") +warnings.filterwarnings(action="ignore", category=FutureWarning, module="huggingface_hub") + +ia_logging = logging.getLogger("Inpaint Anything") +ia_logging.setLevel(logging.INFO) +ia_logging.propagate = False + +ia_logging_sh = logging.StreamHandler() +ia_logging_sh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) +ia_logging_sh.setLevel(logging.INFO) +ia_logging.addHandler(ia_logging_sh) diff --git a/ia_sam_manager.py b/ia_sam_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..eefa387efa58efc1686fc91bd6a1be6319b731bb --- /dev/null +++ b/ia_sam_manager.py @@ -0,0 +1,182 @@ +import os +import platform +from functools import partial + +import torch + +from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry +from ia_check_versions import ia_check_versions +from ia_config import IAConfig +from ia_devices import devices +from ia_logging import ia_logging +from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile +from mobile_sam import SamPredictor as SamPredictorMobile +from mobile_sam import sam_model_registry as sam_model_registry_mobile +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from sam2.build_sam import build_sam2 +from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry +from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ +from segment_anything_hq import SamPredictor as SamPredictorHQ +from segment_anything_hq import sam_model_registry as sam_model_registry_hq + + +def check_bfloat16_support() -> bool: + if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability(torch.cuda.current_device()) + if compute_capability[0] >= 8: + ia_logging.debug("The CUDA device supports bfloat16") + return True + else: + ia_logging.debug("The CUDA device does not support bfloat16") + return False + else: + ia_logging.debug("CUDA is not available") + return False + + +def partial_from_end(func, /, *fixed_args, **fixed_kwargs): + def wrapper(*args, **kwargs): + updated_kwargs = {**fixed_kwargs, **kwargs} + return func(*args, *fixed_args, **updated_kwargs) + return wrapper + + +def rename_args(func, arg_map): + def wrapper(*args, **kwargs): + new_kwargs = {arg_map.get(k, k): v for k, v in kwargs.items()} + return func(*args, **new_kwargs) + return wrapper + + +arg_map = {"checkpoint": "ckpt_path"} +rename_build_sam2 = rename_args(build_sam2, arg_map) +end_kwargs = dict(device="cpu", mode="eval", hydra_overrides_extra=[], apply_postprocessing=False) +sam2_model_registry = { + "sam2_hiera_large": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_l.yaml"), + "sam2_hiera_base_plus": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_b+.yaml"), + "sam2_hiera_small": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_s.yaml"), + "sam2_hiera_tiny": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_t.yaml"), +} + + +def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False): + """Get SAM mask generator. + + Args: + sam_checkpoint (str): SAM checkpoint path + + Returns: + SamAutomaticMaskGenerator or None: SAM mask generator + """ + points_per_batch = 64 + if "_hq_" in os.path.basename(sam_checkpoint): + model_type = os.path.basename(sam_checkpoint)[7:12] + sam_model_registry_local = sam_model_registry_hq + SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorHQ + points_per_batch = 32 + elif "FastSAM" in os.path.basename(sam_checkpoint): + model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0] + sam_model_registry_local = fast_sam_model_registry + SamAutomaticMaskGeneratorLocal = FastSamAutomaticMaskGenerator + points_per_batch = None + elif "mobile_sam" in os.path.basename(sam_checkpoint): + model_type = "vit_t" + sam_model_registry_local = sam_model_registry_mobile + SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorMobile + points_per_batch = 64 + elif "sam2_" in os.path.basename(sam_checkpoint): + model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0] + sam_model_registry_local = sam2_model_registry + SamAutomaticMaskGeneratorLocal = SAM2AutomaticMaskGenerator + points_per_batch = 128 + else: + model_type = os.path.basename(sam_checkpoint)[4:9] + sam_model_registry_local = sam_model_registry + SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGenerator + points_per_batch = 64 + + pred_iou_thresh = 0.88 if not anime_style_chk else 0.83 + stability_score_thresh = 0.95 if not anime_style_chk else 0.9 + + if "sam2_" in model_type: + pred_iou_thresh = round(pred_iou_thresh - 0.18, 2) + stability_score_thresh = round(stability_score_thresh - 0.03, 2) + sam2_gen_kwargs = dict( + points_per_side=64, + points_per_batch=points_per_batch, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + stability_score_offset=0.7, + crop_n_layers=1, + box_nms_thresh=0.7, + crop_n_points_downscale_factor=2) + if platform.system() == "Darwin": + sam2_gen_kwargs.update(dict(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1)) + + if os.path.isfile(sam_checkpoint): + sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint) + if platform.system() == "Darwin": + if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available: + sam.to(device=torch.device("cpu")) + else: + sam.to(device=torch.device("mps")) + else: + if IAConfig.global_args.get("sam_cpu", False): + ia_logging.info("SAM is running on CPU... (the option has been selected)") + sam.to(device=devices.cpu) + else: + sam.to(device=devices.device) + sam_gen_kwargs = dict( + model=sam, points_per_batch=points_per_batch, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) + if "sam2_" in model_type: + sam_gen_kwargs.update(sam2_gen_kwargs) + sam_mask_generator = SamAutomaticMaskGeneratorLocal(**sam_gen_kwargs) + else: + sam_mask_generator = None + + return sam_mask_generator + + +def get_sam_predictor(sam_checkpoint): + """Get SAM predictor. + + Args: + sam_checkpoint (str): SAM checkpoint path + + Returns: + SamPredictor or None: SAM predictor + """ + # model_type = "vit_h" + if "_hq_" in os.path.basename(sam_checkpoint): + model_type = os.path.basename(sam_checkpoint)[7:12] + sam_model_registry_local = sam_model_registry_hq + SamPredictorLocal = SamPredictorHQ + elif "FastSAM" in os.path.basename(sam_checkpoint): + raise NotImplementedError("FastSAM predictor is not implemented yet.") + elif "mobile_sam" in os.path.basename(sam_checkpoint): + model_type = "vit_t" + sam_model_registry_local = sam_model_registry_mobile + SamPredictorLocal = SamPredictorMobile + else: + model_type = os.path.basename(sam_checkpoint)[4:9] + sam_model_registry_local = sam_model_registry + SamPredictorLocal = SamPredictor + + if os.path.isfile(sam_checkpoint): + sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint) + if platform.system() == "Darwin": + if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available: + sam.to(device=torch.device("cpu")) + else: + sam.to(device=torch.device("mps")) + else: + if IAConfig.global_args.get("sam_cpu", False): + ia_logging.info("SAM is running on CPU... (the option has been selected)") + sam.to(device=devices.cpu) + else: + sam.to(device=devices.device) + sam_predictor = SamPredictorLocal(sam) + else: + sam_predictor = None + + return sam_predictor diff --git a/ia_threading.py b/ia_threading.py new file mode 100644 index 0000000000000000000000000000000000000000..f5b0f15f45a06566b0e09ebbbc62e4ca1bd00b12 --- /dev/null +++ b/ia_threading.py @@ -0,0 +1,55 @@ +import gc +import inspect +import threading +from functools import wraps + +import torch + +from ia_check_versions import ia_check_versions + +model_access_sem = threading.Semaphore(1) + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if ia_check_versions.torch_mps_is_available: + if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): + torch.mps.empty_cache() + + +def clear_cache(): + gc.collect() + torch_gc() + + +def post_clear_cache(sem): + with sem: + gc.collect() + torch_gc() + + +def async_post_clear_cache(): + thread = threading.Thread(target=post_clear_cache, args=(model_access_sem,)) + thread.start() + + +def clear_cache_decorator(func): + @wraps(func) + def yield_wrapper(*args, **kwargs): + clear_cache() + yield from func(*args, **kwargs) + clear_cache() + + @wraps(func) + def wrapper(*args, **kwargs): + clear_cache() + res = func(*args, **kwargs) + clear_cache() + return res + + if inspect.isgeneratorfunction(func): + return yield_wrapper + else: + return wrapper diff --git a/ia_ui_gradio.py b/ia_ui_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..d23b568f0af90c2d4da698b19c7f237178fa5018 --- /dev/null +++ b/ia_ui_gradio.py @@ -0,0 +1,30 @@ +import os + +import gradio as gr + +GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse + + +def webpath(fn): + web_path = os.path.realpath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + script_path = os.path.join(os.path.dirname(__file__), "javascript", "inpaint-anything.js") + head = f'\n' + + return head + + +def reload_javascript(): + js = javascript_html() + + def template_response(*args, **kwargs): + res = GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.init_headers() + return res + + gr.routes.templates.TemplateResponse = template_response diff --git a/ia_ui_items.py b/ia_ui_items.py new file mode 100644 index 0000000000000000000000000000000000000000..0244d7cfccc071f2b668d319b1303df81f4fad4a --- /dev/null +++ b/ia_ui_items.py @@ -0,0 +1,110 @@ +from huggingface_hub import scan_cache_dir + + +def get_sampler_names(): + """Get sampler name list. + + Returns: + list: sampler name list + """ + sampler_names = [ + "DDIM", + "Euler", + "Euler a", + "DPM2 Karras", + "DPM2 a Karras", + ] + return sampler_names + + +def get_sam_model_ids(): + """Get SAM model ids list. + + Returns: + list: SAM model ids list + """ + sam_model_ids = [ + "sam2_hiera_large.pt", + "sam2_hiera_base_plus.pt", + "sam2_hiera_small.pt", + "sam2_hiera_tiny.pt", + "sam_vit_h_4b8939.pth", + "sam_vit_l_0b3195.pth", + "sam_vit_b_01ec64.pth", + "sam_hq_vit_h.pth", + "sam_hq_vit_l.pth", + "sam_hq_vit_b.pth", + "FastSAM-x.pt", + "FastSAM-s.pt", + "mobile_sam.pt", + ] + return sam_model_ids + + +inp_list_from_cache = None + + +def get_inp_model_ids(): + """Get inpainting model ids list. + + Returns: + list: model ids list + """ + global inp_list_from_cache + model_ids = [ + "stabilityai/stable-diffusion-2-inpainting", + "Uminosachi/dreamshaper_8Inpainting", + "Uminosachi/deliberate_v3-inpainting", + "Uminosachi/realisticVisionV51_v51VAE-inpainting", + "Uminosachi/revAnimated_v121Inp-inpainting", + "runwayml/stable-diffusion-inpainting", + ] + if inp_list_from_cache is not None and isinstance(inp_list_from_cache, list): + model_ids.extend(inp_list_from_cache) + return model_ids + try: + hf_cache_info = scan_cache_dir() + inpaint_repos = [] + for repo in hf_cache_info.repos: + if repo.repo_type == "model" and "inpaint" in repo.repo_id.lower() and repo.repo_id not in model_ids: + inpaint_repos.append(repo.repo_id) + inp_list_from_cache = sorted(inpaint_repos, reverse=True, key=lambda x: x.split("/")[-1]) + model_ids.extend(inp_list_from_cache) + return model_ids + except Exception: + return model_ids + + +def get_cleaner_model_ids(): + """Get cleaner model ids list. + + Returns: + list: model ids list + """ + model_ids = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "manga", + ] + return model_ids + + +def get_padding_mode_names(): + """Get padding mode name list. + + Returns: + list: padding mode name list + """ + padding_mode_names = [ + "constant", + "edge", + "reflect", + "mean", + "median", + "maximum", + "minimum", + ] + return padding_mode_names diff --git a/iasam_app.py b/iasam_app.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5ebe4c50e0c661a20d9e3205002dbdc3f67673 --- /dev/null +++ b/iasam_app.py @@ -0,0 +1,809 @@ +import argparse +# import math +import gc +import os +import platform + +if platform.system() == "Darwin": + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +if platform.system() == "Windows": + os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1" + +import random +import traceback +from importlib.util import find_spec + +import cv2 +import gradio as gr +import numpy as np +import torch +from diffusers import (DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, + StableDiffusionInpaintPipeline) +from PIL import Image, ImageFilter +from PIL.PngImagePlugin import PngInfo +from torch.hub import download_url_to_file +from torchvision import transforms + +import inpalib +from ia_check_versions import ia_check_versions +from ia_config import IAConfig, get_ia_config_index, set_ia_config, setup_ia_config_ini +from ia_devices import devices +from ia_file_manager import IAFileManager, download_model_from_hf, ia_file_manager +from ia_logging import ia_logging +from ia_threading import clear_cache_decorator +from ia_ui_gradio import reload_javascript +from ia_ui_items import (get_cleaner_model_ids, get_inp_model_ids, get_padding_mode_names, + get_sam_model_ids, get_sampler_names) +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler + +print("platform:", platform.system()) + +reload_javascript() + +if find_spec("xformers") is not None: + xformers_available = True +else: + xformers_available = False + +parser = argparse.ArgumentParser(description="Inpaint Anything") +parser.add_argument("--save-seg", action="store_true", help="Save the segmentation image generated by SAM.") +parser.add_argument("--offline", action="store_true", help="Execute inpainting using an offline network.") +parser.add_argument("--sam-cpu", action="store_true", help="Perform the Segment Anything operation on CPU.") +args = parser.parse_args() +IAConfig.global_args.update(args.__dict__) + + +@clear_cache_decorator +def download_model(sam_model_id): + """Download SAM model. + + Args: + sam_model_id (str): SAM model id + + Returns: + str: download status + """ + if "_hq_" in sam_model_id: + url_sam = "https://huggingface.co/Uminosachi/sam-hq/resolve/main/" + sam_model_id + elif "FastSAM" in sam_model_id: + url_sam = "https://huggingface.co/Uminosachi/FastSAM/resolve/main/" + sam_model_id + elif "mobile_sam" in sam_model_id: + url_sam = "https://huggingface.co/Uminosachi/MobileSAM/resolve/main/" + sam_model_id + elif "sam2_" in sam_model_id: + url_sam = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" + sam_model_id + else: + url_sam = "https://dl.fbaipublicfiles.com/segment_anything/" + sam_model_id + + sam_checkpoint = os.path.join(ia_file_manager.models_dir, sam_model_id) + if not os.path.isfile(sam_checkpoint): + try: + download_url_to_file(url_sam, sam_checkpoint) + except Exception as e: + ia_logging.error(str(e)) + return str(e) + + return IAFileManager.DOWNLOAD_COMPLETE + else: + return "Model already exists" + + +sam_dict = dict(sam_masks=None, mask_image=None, cnet=None, orig_image=None, pad_mask=None) + + +def save_mask_image(mask_image, save_mask_chk=False): + """Save mask image. + + Args: + mask_image (np.ndarray): mask image + save_mask_chk (bool, optional): If True, save mask image. Defaults to False. + + Returns: + None + """ + if save_mask_chk: + save_name = "_".join([ia_file_manager.savename_prefix, "created_mask"]) + ".png" + save_name = os.path.join(ia_file_manager.outputs_dir, save_name) + Image.fromarray(mask_image).save(save_name) + + +@clear_cache_decorator +def input_image_upload(input_image, sam_image, sel_mask): + global sam_dict + sam_dict["orig_image"] = input_image + sam_dict["pad_mask"] = None + + if (sam_dict["mask_image"] is None or not isinstance(sam_dict["mask_image"], np.ndarray) or + sam_dict["mask_image"].shape != input_image.shape): + sam_dict["mask_image"] = np.zeros_like(input_image, dtype=np.uint8) + + ret_sel_image = cv2.addWeighted(input_image, 0.5, sam_dict["mask_image"], 0.5, 0) + + if sam_image is None or not isinstance(sam_image, dict) or "image" not in sam_image: + sam_dict["sam_masks"] = None + ret_sam_image = np.zeros_like(input_image, dtype=np.uint8) + elif sam_image["image"].shape == input_image.shape: + ret_sam_image = gr.update() + else: + sam_dict["sam_masks"] = None + ret_sam_image = gr.update(value=np.zeros_like(input_image, dtype=np.uint8)) + + if sel_mask is None or not isinstance(sel_mask, dict) or "image" not in sel_mask: + ret_sel_mask = ret_sel_image + elif sel_mask["image"].shape == ret_sel_image.shape and np.all(sel_mask["image"] == ret_sel_image): + ret_sel_mask = gr.update() + else: + ret_sel_mask = gr.update(value=ret_sel_image) + + return ret_sam_image, ret_sel_mask, gr.update(interactive=True) + + +@clear_cache_decorator +def run_padding(input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode="edge"): + global sam_dict + if input_image is None or sam_dict["orig_image"] is None: + sam_dict["orig_image"] = None + sam_dict["pad_mask"] = None + return None, "Input image not found" + + orig_image = sam_dict["orig_image"] + + height, width = orig_image.shape[:2] + pad_width, pad_height = (int(width * pad_scale_width), int(height * pad_scale_height)) + ia_logging.info(f"resize by padding: ({height}, {width}) -> ({pad_height}, {pad_width})") + + pad_size_w, pad_size_h = (pad_width - width, pad_height - height) + pad_size_l = int(pad_size_w * pad_lr_barance) + pad_size_r = pad_size_w - pad_size_l + pad_size_t = int(pad_size_h * pad_tb_barance) + pad_size_b = pad_size_h - pad_size_t + + pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r), (0, 0)] + if padding_mode == "constant": + fill_value = 127 + pad_image = np.pad(orig_image, pad_width=pad_width, mode=padding_mode, constant_values=fill_value) + else: + pad_image = np.pad(orig_image, pad_width=pad_width, mode=padding_mode) + + mask_pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r)] + pad_mask = np.zeros((height, width), dtype=np.uint8) + pad_mask = np.pad(pad_mask, pad_width=mask_pad_width, mode="constant", constant_values=255) + sam_dict["pad_mask"] = dict(segmentation=pad_mask.astype(bool)) + + return pad_image, "Padding done" + + +@clear_cache_decorator +def run_sam(input_image, sam_model_id, sam_image, anime_style_chk=False): + global sam_dict + if not inpalib.sam_file_exists(sam_model_id): + ret_sam_image = None if sam_image is None else gr.update() + return ret_sam_image, f"{sam_model_id} not found, please download" + + if input_image is None: + ret_sam_image = None if sam_image is None else gr.update() + return ret_sam_image, "Input image not found" + + set_ia_config(IAConfig.KEYS.SAM_MODEL_ID, sam_model_id, IAConfig.SECTIONS.USER) + + if sam_dict["sam_masks"] is not None: + sam_dict["sam_masks"] = None + gc.collect() + + ia_logging.info(f"input_image: {input_image.shape} {input_image.dtype}") + + try: + sam_masks = inpalib.generate_sam_masks(input_image, sam_model_id, anime_style_chk) + sam_masks = inpalib.sort_masks_by_area(sam_masks) + sam_masks = inpalib.insert_mask_to_sam_masks(sam_masks, sam_dict["pad_mask"]) + + seg_image = inpalib.create_seg_color_image(input_image, sam_masks) + + sam_dict["sam_masks"] = sam_masks + + except Exception as e: + print(traceback.format_exc()) + ia_logging.error(str(e)) + ret_sam_image = None if sam_image is None else gr.update() + return ret_sam_image, "Segment Anything failed" + + if IAConfig.global_args.get("save_seg", False): + save_name = "_".join([ia_file_manager.savename_prefix, os.path.splitext(sam_model_id)[0]]) + ".png" + save_name = os.path.join(ia_file_manager.outputs_dir, save_name) + Image.fromarray(seg_image).save(save_name) + + if sam_image is None: + return seg_image, "Segment Anything complete" + else: + if sam_image["image"].shape == seg_image.shape and np.all(sam_image["image"] == seg_image): + return gr.update(), "Segment Anything complete" + else: + return gr.update(value=seg_image), "Segment Anything complete" + + +@clear_cache_decorator +def select_mask(input_image, sam_image, invert_chk, ignore_black_chk, sel_mask): + global sam_dict + if sam_dict["sam_masks"] is None or sam_image is None: + ret_sel_mask = None if sel_mask is None else gr.update() + return ret_sel_mask + sam_masks = sam_dict["sam_masks"] + + # image = sam_image["image"] + mask = sam_image["mask"][:, :, 0:1] + + try: + seg_image = inpalib.create_mask_image(mask, sam_masks, ignore_black_chk) + if invert_chk: + seg_image = inpalib.invert_mask(seg_image) + + sam_dict["mask_image"] = seg_image + + except Exception as e: + print(traceback.format_exc()) + ia_logging.error(str(e)) + ret_sel_mask = None if sel_mask is None else gr.update() + return ret_sel_mask + + if input_image is not None and input_image.shape == seg_image.shape: + ret_image = cv2.addWeighted(input_image, 0.5, seg_image, 0.5, 0) + else: + ret_image = seg_image + + if sel_mask is None: + return ret_image + else: + if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image): + return gr.update() + else: + return gr.update(value=ret_image) + + +@clear_cache_decorator +def expand_mask(input_image, sel_mask, expand_iteration=1): + global sam_dict + if sam_dict["mask_image"] is None or sel_mask is None: + return None + + new_sel_mask = sam_dict["mask_image"] + + expand_iteration = int(np.clip(expand_iteration, 1, 100)) + + new_sel_mask = cv2.dilate(new_sel_mask, np.ones((3, 3), dtype=np.uint8), iterations=expand_iteration) + + sam_dict["mask_image"] = new_sel_mask + + if input_image is not None and input_image.shape == new_sel_mask.shape: + ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0) + else: + ret_image = new_sel_mask + + if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image): + return gr.update() + else: + return gr.update(value=ret_image) + + +@clear_cache_decorator +def apply_mask(input_image, sel_mask): + global sam_dict + if sam_dict["mask_image"] is None or sel_mask is None: + return None + + sel_mask_image = sam_dict["mask_image"] + sel_mask_mask = np.logical_not(sel_mask["mask"][:, :, 0:3].astype(bool)).astype(np.uint8) + new_sel_mask = sel_mask_image * sel_mask_mask + + sam_dict["mask_image"] = new_sel_mask + + if input_image is not None and input_image.shape == new_sel_mask.shape: + ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0) + else: + ret_image = new_sel_mask + + if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image): + return gr.update() + else: + return gr.update(value=ret_image) + + +@clear_cache_decorator +def add_mask(input_image, sel_mask): + global sam_dict + if sam_dict["mask_image"] is None or sel_mask is None: + return None + + sel_mask_image = sam_dict["mask_image"] + sel_mask_mask = sel_mask["mask"][:, :, 0:3].astype(bool).astype(np.uint8) + new_sel_mask = sel_mask_image + (sel_mask_mask * np.invert(sel_mask_image, dtype=np.uint8)) + + sam_dict["mask_image"] = new_sel_mask + + if input_image is not None and input_image.shape == new_sel_mask.shape: + ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0) + else: + ret_image = new_sel_mask + + if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image): + return gr.update() + else: + return gr.update(value=ret_image) + + +def auto_resize_to_pil(input_image, mask_image): + init_image = Image.fromarray(input_image).convert("RGB") + mask_image = Image.fromarray(mask_image).convert("RGB") + assert init_image.size == mask_image.size, "The sizes of the image and mask do not match" + width, height = init_image.size + + new_height = (height // 8) * 8 + new_width = (width // 8) * 8 + if new_width < width or new_height < height: + if (new_width / width) < (new_height / height): + scale = new_height / height + else: + scale = new_width / width + resize_height = int(height*scale+0.5) + resize_width = int(width*scale+0.5) + if height != resize_height or width != resize_width: + ia_logging.info(f"resize: ({height}, {width}) -> ({resize_height}, {resize_width})") + init_image = transforms.functional.resize(init_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS) + mask_image = transforms.functional.resize(mask_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS) + if resize_height != new_height or resize_width != new_width: + ia_logging.info(f"center_crop: ({resize_height}, {resize_width}) -> ({new_height}, {new_width})") + init_image = transforms.functional.center_crop(init_image, (new_height, new_width)) + mask_image = transforms.functional.center_crop(mask_image, (new_height, new_width)) + + return init_image, mask_image + + +@clear_cache_decorator +def run_inpaint(input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, inp_model_id, save_mask_chk, composite_chk, + sampler_name="DDIM", iteration_count=1): + global sam_dict + if input_image is None or sam_dict["mask_image"] is None or sel_mask is None: + ia_logging.error("The image or mask does not exist") + return + + mask_image = sam_dict["mask_image"] + if input_image.shape != mask_image.shape: + ia_logging.error("The sizes of the image and mask do not match") + return + + set_ia_config(IAConfig.KEYS.INP_MODEL_ID, inp_model_id, IAConfig.SECTIONS.USER) + + save_mask_image(mask_image, save_mask_chk) + + ia_logging.info(f"Loading model {inp_model_id}") + config_offline_inpainting = IAConfig.global_args.get("offline", False) + if config_offline_inpainting: + ia_logging.info("Run Inpainting on offline network: {}".format(str(config_offline_inpainting))) + local_files_only = False + local_file_status = download_model_from_hf(inp_model_id, local_files_only=True) + if local_file_status != IAFileManager.DOWNLOAD_COMPLETE: + if config_offline_inpainting: + ia_logging.warning(local_file_status) + return + else: + local_files_only = True + ia_logging.info("local_files_only: {}".format(str(local_files_only))) + + if platform.system() == "Darwin" or devices.device == devices.cpu or ia_check_versions.torch_on_amd_rocm: + torch_dtype = torch.float32 + else: + torch_dtype = torch.float16 + + try: + pipe = StableDiffusionInpaintPipeline.from_pretrained( + inp_model_id, torch_dtype=torch_dtype, local_files_only=local_files_only, use_safetensors=True) + except Exception as e: + ia_logging.error(str(e)) + if not config_offline_inpainting: + try: + pipe = StableDiffusionInpaintPipeline.from_pretrained( + inp_model_id, torch_dtype=torch_dtype, use_safetensors=True) + except Exception as e: + ia_logging.error(str(e)) + try: + pipe = StableDiffusionInpaintPipeline.from_pretrained( + inp_model_id, torch_dtype=torch_dtype, force_download=True, use_safetensors=True) + except Exception as e: + ia_logging.error(str(e)) + return + else: + return + pipe.safety_checker = None + + ia_logging.info(f"Using sampler {sampler_name}") + if sampler_name == "DDIM": + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + elif sampler_name == "Euler": + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + elif sampler_name == "Euler a": + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + elif sampler_name == "DPM2 Karras": + pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config) + elif sampler_name == "DPM2 a Karras": + pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config) + else: + ia_logging.info("Sampler fallback to DDIM") + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + + if platform.system() == "Darwin": + pipe = pipe.to("mps" if ia_check_versions.torch_mps_is_available else "cpu") + pipe.enable_attention_slicing() + torch_generator = torch.Generator(devices.cpu) + else: + if ia_check_versions.diffusers_enable_cpu_offload and devices.device != devices.cpu: + ia_logging.info("Enable model cpu offload") + pipe.enable_model_cpu_offload() + else: + pipe = pipe.to(devices.device) + if xformers_available: + ia_logging.info("Enable xformers memory efficient attention") + pipe.enable_xformers_memory_efficient_attention() + else: + ia_logging.info("Enable attention slicing") + pipe.enable_attention_slicing() + if "privateuseone" in str(getattr(devices.device, "type", "")): + torch_generator = torch.Generator(devices.cpu) + else: + torch_generator = torch.Generator(devices.device) + + init_image, mask_image = auto_resize_to_pil(input_image, mask_image) + width, height = init_image.size + + output_list = [] + iteration_count = iteration_count if iteration_count is not None else 1 + for count in range(int(iteration_count)): + gc.collect() + if seed < 0 or count > 0: + seed = random.randint(0, 2147483647) + + generator = torch_generator.manual_seed(seed) + + pipe_args_dict = { + "prompt": prompt, + "image": init_image, + "width": width, + "height": height, + "mask_image": mask_image, + "num_inference_steps": ddim_steps, + "guidance_scale": cfg_scale, + "negative_prompt": n_prompt, + "generator": generator, + } + + output_image = pipe(**pipe_args_dict).images[0] + + if composite_chk: + dilate_mask_image = Image.fromarray(cv2.dilate(np.array(mask_image), np.ones((3, 3), dtype=np.uint8), iterations=4)) + output_image = Image.composite(output_image, init_image, dilate_mask_image.convert("L").filter(ImageFilter.GaussianBlur(3))) + + generation_params = { + "Steps": ddim_steps, + "Sampler": sampler_name, + "CFG scale": cfg_scale, + "Seed": seed, + "Size": f"{width}x{height}", + "Model": inp_model_id, + } + + generation_params_text = ", ".join([k if k == v else f"{k}: {v}" for k, v in generation_params.items() if v is not None]) + prompt_text = prompt if prompt else "" + negative_prompt_text = "\nNegative prompt: " + n_prompt if n_prompt else "" + infotext = f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip() + + metadata = PngInfo() + metadata.add_text("parameters", infotext) + + save_name = "_".join([ia_file_manager.savename_prefix, os.path.basename(inp_model_id), str(seed)]) + ".png" + save_name = os.path.join(ia_file_manager.outputs_dir, save_name) + output_image.save(save_name, pnginfo=metadata) + + output_list.append(output_image) + + yield output_list, max([1, iteration_count - (count + 1)]) + + +@clear_cache_decorator +def run_cleaner(input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk): + global sam_dict + if input_image is None or sam_dict["mask_image"] is None or sel_mask is None: + ia_logging.error("The image or mask does not exist") + return None + + mask_image = sam_dict["mask_image"] + if input_image.shape != mask_image.shape: + ia_logging.error("The sizes of the image and mask do not match") + return None + + save_mask_image(mask_image, cleaner_save_mask_chk) + + ia_logging.info(f"Loading model {cleaner_model_id}") + if platform.system() == "Darwin": + model = ModelManager(name=cleaner_model_id, device=devices.cpu) + else: + model = ModelManager(name=cleaner_model_id, device=devices.device) + + init_image, mask_image = auto_resize_to_pil(input_image, mask_image) + width, height = init_image.size + + init_image = np.array(init_image) + mask_image = np.array(mask_image.convert("L")) + + config = Config( + ldm_steps=20, + ldm_sampler=LDMSampler.ddim, + hd_strategy=HDStrategy.ORIGINAL, + hd_strategy_crop_margin=32, + hd_strategy_crop_trigger_size=512, + hd_strategy_resize_limit=512, + prompt="", + sd_steps=20, + sd_sampler=SDSampler.ddim + ) + + output_image = model(image=init_image, mask=mask_image, config=config) + output_image = cv2.cvtColor(output_image.astype(np.uint8), cv2.COLOR_BGR2RGB) + output_image = Image.fromarray(output_image) + + save_name = "_".join([ia_file_manager.savename_prefix, os.path.basename(cleaner_model_id)]) + ".png" + save_name = os.path.join(ia_file_manager.outputs_dir, save_name) + output_image.save(save_name) + + del model + return [output_image] + + +@clear_cache_decorator +def run_get_alpha_image(input_image, sel_mask): + global sam_dict + if input_image is None or sam_dict["mask_image"] is None or sel_mask is None: + ia_logging.error("The image or mask does not exist") + return None, "" + + mask_image = sam_dict["mask_image"] + if input_image.shape != mask_image.shape: + ia_logging.error("The sizes of the image and mask do not match") + return None, "" + + alpha_image = Image.fromarray(input_image).convert("RGBA") + mask_image = Image.fromarray(mask_image).convert("L") + + alpha_image.putalpha(mask_image) + + save_name = "_".join([ia_file_manager.savename_prefix, "rgba_image"]) + ".png" + save_name = os.path.join(ia_file_manager.outputs_dir, save_name) + alpha_image.save(save_name) + + return alpha_image, f"saved: {save_name}" + + +@clear_cache_decorator +def run_get_mask(sel_mask): + global sam_dict + if sam_dict["mask_image"] is None or sel_mask is None: + return None + + mask_image = sam_dict["mask_image"] + + save_name = "_".join([ia_file_manager.savename_prefix, "created_mask"]) + ".png" + save_name = os.path.join(ia_file_manager.outputs_dir, save_name) + Image.fromarray(mask_image).save(save_name) + + return mask_image + + +def on_ui_tabs(): + setup_ia_config_ini() + sampler_names = get_sampler_names() + sam_model_ids = get_sam_model_ids() + sam_model_index = get_ia_config_index(IAConfig.KEYS.SAM_MODEL_ID, IAConfig.SECTIONS.USER) + inp_model_ids = get_inp_model_ids() + inp_model_index = get_ia_config_index(IAConfig.KEYS.INP_MODEL_ID, IAConfig.SECTIONS.USER) + cleaner_model_ids = get_cleaner_model_ids() + padding_mode_names = get_padding_mode_names() + + out_gallery_kwargs = dict(columns=2, height=520, object_fit="contain", preview=True) + + block = gr.Blocks(analytics_enabled=False).queue() + block.title = "Inpaint Anything" + with block as inpaint_anything_interface: + with gr.Row(): + gr.Markdown("## Inpainting with Segment Anything") + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + sam_model_id = gr.Dropdown(label="Segment Anything Model ID", elem_id="sam_model_id", choices=sam_model_ids, + value=sam_model_ids[sam_model_index], show_label=True) + with gr.Column(): + with gr.Row(): + load_model_btn = gr.Button("Download model", elem_id="load_model_btn") + with gr.Row(): + status_text = gr.Textbox(label="", elem_id="status_text", max_lines=1, show_label=False, interactive=False) + with gr.Row(): + input_image = gr.Image(label="Input image", elem_id="ia_input_image", source="upload", type="numpy", interactive=True) + + with gr.Row(): + with gr.Accordion("Padding options", elem_id="padding_options", open=False): + with gr.Row(): + with gr.Column(): + pad_scale_width = gr.Slider(label="Scale Width", elem_id="pad_scale_width", minimum=1.0, maximum=1.5, value=1.0, step=0.01) + with gr.Column(): + pad_lr_barance = gr.Slider(label="Left/Right Balance", elem_id="pad_lr_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01) + with gr.Row(): + with gr.Column(): + pad_scale_height = gr.Slider(label="Scale Height", elem_id="pad_scale_height", minimum=1.0, maximum=1.5, value=1.0, step=0.01) + with gr.Column(): + pad_tb_barance = gr.Slider(label="Top/Bottom Balance", elem_id="pad_tb_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01) + with gr.Row(): + with gr.Column(): + padding_mode = gr.Dropdown(label="Padding Mode", elem_id="padding_mode", choices=padding_mode_names, value="edge") + with gr.Column(): + padding_btn = gr.Button("Run Padding", elem_id="padding_btn") + + with gr.Row(): + with gr.Column(): + anime_style_chk = gr.Checkbox(label="Anime Style (Up Detection, Down mask Quality)", elem_id="anime_style_chk", + show_label=True, interactive=True) + with gr.Column(): + sam_btn = gr.Button("Run Segment Anything", elem_id="sam_btn", variant="primary", interactive=False) + + with gr.Tab("Inpainting", elem_id="inpainting_tab"): + prompt = gr.Textbox(label="Inpainting Prompt", elem_id="sd_prompt") + n_prompt = gr.Textbox(label="Negative Prompt", elem_id="sd_n_prompt") + with gr.Accordion("Advanced options", elem_id="inp_advanced_options", open=False): + composite_chk = gr.Checkbox(label="Mask area Only", elem_id="composite_chk", value=True, show_label=True, interactive=True) + with gr.Row(): + with gr.Column(): + sampler_name = gr.Dropdown(label="Sampler", elem_id="sampler_name", choices=sampler_names, + value=sampler_names[0], show_label=True) + with gr.Column(): + ddim_steps = gr.Slider(label="Sampling Steps", elem_id="ddim_steps", minimum=1, maximum=100, value=20, step=1) + cfg_scale = gr.Slider(label="Guidance Scale", elem_id="cfg_scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1) + seed = gr.Slider( + label="Seed", + elem_id="sd_seed", + minimum=-1, + maximum=2147483647, + step=1, + value=-1, + ) + with gr.Row(): + with gr.Column(): + inp_model_id = gr.Dropdown(label="Inpainting Model ID", elem_id="inp_model_id", + choices=inp_model_ids, value=inp_model_ids[inp_model_index], show_label=True) + with gr.Column(): + with gr.Row(): + inpaint_btn = gr.Button("Run Inpainting", elem_id="inpaint_btn", variant="primary") + with gr.Row(): + save_mask_chk = gr.Checkbox(label="Save mask", elem_id="save_mask_chk", + value=False, show_label=False, interactive=False, visible=False) + iteration_count = gr.Slider(label="Iterations", elem_id="iteration_count", minimum=1, maximum=10, value=1, step=1) + + with gr.Row(): + if ia_check_versions.gradio_version_is_old: + out_image = gr.Gallery(label="Inpainted image", elem_id="ia_out_image", show_label=False + ).style(**out_gallery_kwargs) + else: + out_image = gr.Gallery(label="Inpainted image", elem_id="ia_out_image", show_label=False, + **out_gallery_kwargs) + + with gr.Tab("Cleaner", elem_id="cleaner_tab"): + with gr.Row(): + with gr.Column(): + cleaner_model_id = gr.Dropdown(label="Cleaner Model ID", elem_id="cleaner_model_id", + choices=cleaner_model_ids, value=cleaner_model_ids[0], show_label=True) + with gr.Column(): + with gr.Row(): + cleaner_btn = gr.Button("Run Cleaner", elem_id="cleaner_btn", variant="primary") + with gr.Row(): + cleaner_save_mask_chk = gr.Checkbox(label="Save mask", elem_id="cleaner_save_mask_chk", + value=False, show_label=False, interactive=False, visible=False) + + with gr.Row(): + if ia_check_versions.gradio_version_is_old: + cleaner_out_image = gr.Gallery(label="Cleaned image", elem_id="ia_cleaner_out_image", show_label=False + ).style(**out_gallery_kwargs) + else: + cleaner_out_image = gr.Gallery(label="Cleaned image", elem_id="ia_cleaner_out_image", show_label=False, + **out_gallery_kwargs) + + with gr.Tab("Mask only", elem_id="mask_only_tab"): + with gr.Row(): + with gr.Column(): + get_alpha_image_btn = gr.Button("Get mask as alpha of image", elem_id="get_alpha_image_btn") + with gr.Column(): + get_mask_btn = gr.Button("Get mask", elem_id="get_mask_btn") + + with gr.Row(): + with gr.Column(): + alpha_out_image = gr.Image(label="Alpha channel image", elem_id="alpha_out_image", type="pil", image_mode="RGBA", interactive=False) + with gr.Column(): + mask_out_image = gr.Image(label="Mask image", elem_id="mask_out_image", type="numpy", interactive=False) + + with gr.Row(): + with gr.Column(): + get_alpha_status_text = gr.Textbox(label="", elem_id="get_alpha_status_text", max_lines=1, show_label=False, interactive=False) + with gr.Column(): + gr.Markdown("") + + with gr.Column(): + with gr.Row(): + gr.Markdown("Mouse over image: Press `S` key for Fullscreen mode, `R` key to Reset zoom") + with gr.Row(): + if ia_check_versions.gradio_version_is_old: + sam_image = gr.Image(label="Segment Anything image", elem_id="ia_sam_image", type="numpy", tool="sketch", brush_radius=8, + show_label=False, interactive=True).style(height=480) + else: + sam_image = gr.Image(label="Segment Anything image", elem_id="ia_sam_image", type="numpy", tool="sketch", brush_radius=8, + show_label=False, interactive=True, height=480) + + with gr.Row(): + with gr.Column(): + select_btn = gr.Button("Create Mask", elem_id="select_btn", variant="primary") + with gr.Column(): + with gr.Row(): + invert_chk = gr.Checkbox(label="Invert mask", elem_id="invert_chk", show_label=True, interactive=True) + ignore_black_chk = gr.Checkbox(label="Ignore black area", elem_id="ignore_black_chk", value=True, show_label=True, interactive=True) + + with gr.Row(): + if ia_check_versions.gradio_version_is_old: + sel_mask = gr.Image(label="Selected mask image", elem_id="ia_sel_mask", type="numpy", tool="sketch", brush_radius=12, + show_label=False, interactive=True).style(height=480) + else: + sel_mask = gr.Image(label="Selected mask image", elem_id="ia_sel_mask", type="numpy", tool="sketch", brush_radius=12, + show_label=False, interactive=True, height=480) + + with gr.Row(): + with gr.Column(): + expand_mask_btn = gr.Button("Expand mask region", elem_id="expand_mask_btn") + expand_mask_iteration_count = gr.Slider(label="Expand Mask Iterations", + elem_id="expand_mask_iteration_count", minimum=1, maximum=100, value=1, step=1) + with gr.Column(): + apply_mask_btn = gr.Button("Trim mask by sketch", elem_id="apply_mask_btn") + add_mask_btn = gr.Button("Add mask by sketch", elem_id="add_mask_btn") + + load_model_btn.click(download_model, inputs=[sam_model_id], outputs=[status_text]) + input_image.upload(input_image_upload, inputs=[input_image, sam_image, sel_mask], outputs=[sam_image, sel_mask, sam_btn]).then( + fn=None, inputs=None, outputs=None, _js="inpaintAnything_initSamSelMask") + padding_btn.click(run_padding, inputs=[input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode], + outputs=[input_image, status_text]) + sam_btn.click(run_sam, inputs=[input_image, sam_model_id, sam_image, anime_style_chk], outputs=[sam_image, status_text]).then( + fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSamMask") + select_btn.click(select_mask, inputs=[input_image, sam_image, invert_chk, ignore_black_chk, sel_mask], outputs=[sel_mask]).then( + fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask") + expand_mask_btn.click(expand_mask, inputs=[input_image, sel_mask, expand_mask_iteration_count], outputs=[sel_mask]).then( + fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask") + apply_mask_btn.click(apply_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then( + fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask") + add_mask_btn.click(add_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then( + fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask") + + inpaint_btn.click( + run_inpaint, + inputs=[input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, inp_model_id, save_mask_chk, composite_chk, + sampler_name, iteration_count], + outputs=[out_image, iteration_count]) + cleaner_btn.click( + run_cleaner, + inputs=[input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk], + outputs=[cleaner_out_image]) + get_alpha_image_btn.click( + run_get_alpha_image, + inputs=[input_image, sel_mask], + outputs=[alpha_out_image, get_alpha_status_text]) + get_mask_btn.click( + run_get_mask, + inputs=[sel_mask], + outputs=[mask_out_image]) + + return [(inpaint_anything_interface, "Inpaint Anything", "inpaint_anything")] + + +block, _, _ = on_ui_tabs()[0] +block.launch(share=True) diff --git a/images/inpaint_anything_explanation_image_1.png b/images/inpaint_anything_explanation_image_1.png new file mode 100644 index 0000000000000000000000000000000000000000..6262fa62aaf4c390041fafdd19f4dd9e9e4bc059 Binary files /dev/null and b/images/inpaint_anything_explanation_image_1.png differ diff --git a/images/inpaint_anything_ui_image_1.png b/images/inpaint_anything_ui_image_1.png new file mode 100644 index 0000000000000000000000000000000000000000..0f2ea7cea97ba33c26f9f2b31bad1174d0042c79 Binary files /dev/null and b/images/inpaint_anything_ui_image_1.png differ diff --git a/images/sample_input_image.png b/images/sample_input_image.png new file mode 100644 index 0000000000000000000000000000000000000000..e84dfc8554d344a69afb1fe8c7b8b2997d4e5e11 Binary files /dev/null and b/images/sample_input_image.png differ diff --git a/images/sample_mask_image.png b/images/sample_mask_image.png new file mode 100644 index 0000000000000000000000000000000000000000..3d26622c7ac31b7d18bdf0679ec89380bc79584f Binary files /dev/null and b/images/sample_mask_image.png differ diff --git a/images/sample_seg_color_image.png b/images/sample_seg_color_image.png new file mode 100644 index 0000000000000000000000000000000000000000..fdc3ddf73642b6e0fd43f2c540fd8ccf05eef40d Binary files /dev/null and b/images/sample_seg_color_image.png differ diff --git a/inpalib/__init__.py b/inpalib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10d945cf5faea2f7d219f6a7b34d2301b1f4b0e2 --- /dev/null +++ b/inpalib/__init__.py @@ -0,0 +1,18 @@ +from .masklib import create_mask_image, invert_mask +from .samlib import (create_seg_color_image, generate_sam_masks, get_all_sam_ids, + get_available_sam_ids, get_seg_colormap, insert_mask_to_sam_masks, + sam_file_exists, sam_file_path, sort_masks_by_area) + +__all__ = [ + "create_mask_image", + "invert_mask", + "create_seg_color_image", + "generate_sam_masks", + "get_all_sam_ids", + "get_available_sam_ids", + "get_seg_colormap", + "insert_mask_to_sam_masks", + "sam_file_exists", + "sam_file_path", + "sort_masks_by_area", +] diff --git a/inpalib/masklib.py b/inpalib/masklib.py new file mode 100644 index 0000000000000000000000000000000000000000..8717a9714370644bd214cc1981a3a5aa3cc2fa89 --- /dev/null +++ b/inpalib/masklib.py @@ -0,0 +1,106 @@ +from typing import Any, Dict, List, Union + +import numpy as np +from PIL import Image + + +def invert_mask(mask: np.ndarray) -> np.ndarray: + """Invert mask. + + Args: + mask (np.ndarray): mask + + Returns: + np.ndarray: inverted mask + """ + if mask is None or not isinstance(mask, np.ndarray): + raise ValueError("Invalid mask") + + # return np.logical_not(mask.astype(bool)).astype(np.uint8) * 255 + return np.invert(mask.astype(np.uint8)) + + +def check_inputs_create_mask_image( + mask: Union[np.ndarray, Image.Image], + sam_masks: List[Dict[str, Any]], + ignore_black_chk: bool = True, +) -> None: + """Check create mask image inputs. + + Args: + mask (Union[np.ndarray, Image.Image]): mask + sam_masks (List[Dict[str, Any]]): SAM masks + ignore_black_chk (bool): ignore black check + + Returns: + None + """ + if mask is None or not isinstance(mask, (np.ndarray, Image.Image)): + raise ValueError("Invalid mask") + + if sam_masks is None or not isinstance(sam_masks, list): + raise ValueError("Invalid SAM masks") + + if ignore_black_chk is None or not isinstance(ignore_black_chk, bool): + raise ValueError("Invalid ignore black check") + + +def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray: + """Convert mask. + + Args: + mask (Union[np.ndarray, Image.Image]): mask + + Returns: + np.ndarray: converted mask + """ + if isinstance(mask, Image.Image): + mask = np.array(mask) + + if mask.ndim == 2: + mask = mask[:, :, np.newaxis] + + if mask.shape[2] != 1: + mask = mask[:, :, 0:1] + + return mask + + +def create_mask_image( + mask: Union[np.ndarray, Image.Image], + sam_masks: List[Dict[str, Any]], + ignore_black_chk: bool = True, +) -> np.ndarray: + """Create mask image. + + Args: + mask (Union[np.ndarray, Image.Image]): mask + sam_masks (List[Dict[str, Any]]): SAM masks + ignore_black_chk (bool): ignore black check + + Returns: + np.ndarray: mask image + """ + check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk) + mask = convert_mask(mask) + + canvas_image = np.zeros(mask.shape, dtype=np.uint8) + mask_region = np.zeros(mask.shape, dtype=np.uint8) + for seg_dict in sam_masks: + seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1) + canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8) + if (seg_mask * canvas_mask * mask).astype(bool).any(): + mask_region = mask_region + (seg_mask * canvas_mask) + seg_color = seg_mask * canvas_mask + canvas_image = canvas_image + seg_color + + if not ignore_black_chk: + canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8) + if (canvas_mask * mask).astype(bool).any(): + mask_region = mask_region + (canvas_mask) + + mask_region = np.tile(mask_region * 255, (1, 1, 3)) + + seg_image = mask_region.astype(np.uint8) + + return seg_image diff --git a/inpalib/samlib.py b/inpalib/samlib.py new file mode 100644 index 0000000000000000000000000000000000000000..61417243359752cb283e9880c7f73154c158ecf1 --- /dev/null +++ b/inpalib/samlib.py @@ -0,0 +1,256 @@ +import copy +import os +import sys +from typing import Any, Dict, List, Union + +import cv2 +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + +inpa_basedir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..")) +if inpa_basedir not in sys.path: + sys.path.append(inpa_basedir) + +from ia_file_manager import ia_file_manager # noqa: E402 +from ia_get_dataset_colormap import create_pascal_label_colormap # noqa: E402 +from ia_logging import ia_logging # noqa: E402 +from ia_sam_manager import check_bfloat16_support, get_sam_mask_generator # noqa: E402 +from ia_ui_items import get_sam_model_ids # noqa: E402 + + +def get_all_sam_ids() -> List[str]: + """Get all SAM IDs. + + Returns: + List[str]: SAM IDs + """ + return get_sam_model_ids() + + +def sam_file_path(sam_id: str) -> str: + """Get SAM file path. + + Args: + sam_id (str): SAM ID + + Returns: + str: SAM file path + """ + return os.path.join(ia_file_manager.models_dir, sam_id) + + +def sam_file_exists(sam_id: str) -> bool: + """Check if SAM file exists. + + Args: + sam_id (str): SAM ID + + Returns: + bool: True if SAM file exists else False + """ + sam_checkpoint = sam_file_path(sam_id) + + return os.path.isfile(sam_checkpoint) + + +def get_available_sam_ids() -> List[str]: + """Get available SAM IDs. + + Returns: + List[str]: available SAM IDs + """ + all_sam_ids = get_all_sam_ids() + for sam_id in all_sam_ids.copy(): + if not sam_file_exists(sam_id): + all_sam_ids.remove(sam_id) + + return all_sam_ids + + +def check_inputs_generate_sam_masks( + input_image: Union[np.ndarray, Image.Image], + sam_id: str, + anime_style_chk: bool = False, +) -> None: + """Check generate SAM masks inputs. + + Args: + input_image (Union[np.ndarray, Image.Image]): input image + sam_id (str): SAM ID + anime_style_chk (bool): anime style check + + Returns: + None + """ + if input_image is None or not isinstance(input_image, (np.ndarray, Image.Image)): + raise ValueError("Invalid input image") + + if sam_id is None or not isinstance(sam_id, str): + raise ValueError("Invalid SAM ID") + + if anime_style_chk is None or not isinstance(anime_style_chk, bool): + raise ValueError("Invalid anime style check") + + +def convert_input_image(input_image: Union[np.ndarray, Image.Image]) -> np.ndarray: + """Convert input image. + + Args: + input_image (Union[np.ndarray, Image.Image]): input image + + Returns: + np.ndarray: converted input image + """ + if isinstance(input_image, Image.Image): + input_image = np.array(input_image) + + if input_image.ndim == 2: + input_image = input_image[:, :, np.newaxis] + + if input_image.shape[2] == 1: + input_image = np.concatenate([input_image] * 3, axis=-1) + + return input_image + + +def generate_sam_masks( + input_image: Union[np.ndarray, Image.Image], + sam_id: str, + anime_style_chk: bool = False, +) -> List[Dict[str, Any]]: + """Generate SAM masks. + + Args: + input_image (Union[np.ndarray, Image.Image]): input image + sam_id (str): SAM ID + anime_style_chk (bool): anime style check + + Returns: + List[Dict[str, Any]]: SAM masks + """ + check_inputs_generate_sam_masks(input_image, sam_id, anime_style_chk) + input_image = convert_input_image(input_image) + + sam_checkpoint = sam_file_path(sam_id) + sam_mask_generator = get_sam_mask_generator(sam_checkpoint, anime_style_chk) + ia_logging.info(f"{sam_mask_generator.__class__.__name__} {sam_id}") + + if "sam2_" in sam_id: + device = "cuda" if torch.cuda.is_available() else "cpu" + torch_dtype = torch.bfloat16 if check_bfloat16_support() else torch.float16 + with torch.inference_mode(), torch.autocast(device, dtype=torch_dtype): + sam_masks = sam_mask_generator.generate(input_image) + else: + sam_masks = sam_mask_generator.generate(input_image) + + if anime_style_chk: + for sam_mask in sam_masks: + sam_mask_seg = sam_mask["segmentation"] + sam_mask_seg = cv2.morphologyEx(sam_mask_seg.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8)) + sam_mask_seg = cv2.morphologyEx(sam_mask_seg.astype(np.uint8), cv2.MORPH_OPEN, np.ones((5, 5), np.uint8)) + sam_mask["segmentation"] = sam_mask_seg.astype(bool) + + ia_logging.info("sam_masks: {}".format(len(sam_masks))) + + sam_masks = copy.deepcopy(sam_masks) + return sam_masks + + +def sort_masks_by_area( + sam_masks: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Sort mask by area. + + Args: + sam_masks (List[Dict[str, Any]]): SAM masks + + Returns: + List[Dict[str, Any]]: sorted SAM masks + """ + return sorted(sam_masks, key=lambda x: np.sum(x.get("segmentation").astype(np.uint32))) + + +def get_seg_colormap() -> np.ndarray: + """Get segmentation colormap. + + Returns: + np.ndarray: segmentation colormap + """ + cm_pascal = create_pascal_label_colormap() + seg_colormap = cm_pascal + seg_colormap = np.array([c for c in seg_colormap if max(c) >= 64], dtype=np.uint8) + + return seg_colormap + + +def insert_mask_to_sam_masks( + sam_masks: List[Dict[str, Any]], + insert_mask: Dict[str, Any], +) -> List[Dict[str, Any]]: + """Insert mask to SAM masks. + + Args: + sam_masks (List[Dict[str, Any]]): SAM masks + insert_mask (Dict[str, Any]): insert mask + + Returns: + List[Dict[str, Any]]: SAM masks + """ + if insert_mask is not None and isinstance(insert_mask, dict) and "segmentation" in insert_mask: + if (len(sam_masks) > 0 and + sam_masks[0]["segmentation"].shape == insert_mask["segmentation"].shape and + np.any(insert_mask["segmentation"])): + sam_masks.insert(0, insert_mask) + ia_logging.info("insert mask to sam_masks") + + return sam_masks + + +def create_seg_color_image( + input_image: Union[np.ndarray, Image.Image], + sam_masks: List[Dict[str, Any]], +) -> np.ndarray: + """Create segmentation color image. + + Args: + input_image (Union[np.ndarray, Image.Image]): input image + sam_masks (List[Dict[str, Any]]): SAM masks + + Returns: + np.ndarray: segmentation color image + """ + input_image = convert_input_image(input_image) + + seg_colormap = get_seg_colormap() + sam_masks = sam_masks[:len(seg_colormap)] + + with tqdm(total=len(sam_masks), desc="Processing segments") as progress_bar: + canvas_image = np.zeros((*input_image.shape[:2], 1), dtype=np.uint8) + for idx, seg_dict in enumerate(sam_masks[0:min(255, len(sam_masks))]): + seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1) + canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8) + seg_color = np.array([idx+1], dtype=np.uint8) * seg_mask * canvas_mask + canvas_image = canvas_image + seg_color + progress_bar.update(1) + seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0) + temp_canvas_image = np.apply_along_axis(lambda x: seg_colormap[x[0]], axis=-1, arr=canvas_image) + if len(sam_masks) > 255: + canvas_image = canvas_image.astype(bool).astype(np.uint8) + for idx, seg_dict in enumerate(sam_masks[255:min(509, len(sam_masks))]): + seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1) + canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8) + seg_color = np.array([idx+2], dtype=np.uint8) * seg_mask * canvas_mask + canvas_image = canvas_image + seg_color + progress_bar.update(1) + seg_colormap = seg_colormap[256:] + seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0) + seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0) + canvas_image = np.apply_along_axis(lambda x: seg_colormap[x[0]], axis=-1, arr=canvas_image) + canvas_image = temp_canvas_image + canvas_image + else: + canvas_image = temp_canvas_image + ret_seg_image = canvas_image.astype(np.uint8) + + return ret_seg_image diff --git a/javascript/inpaint-anything.js b/javascript/inpaint-anything.js new file mode 100644 index 0000000000000000000000000000000000000000..776a83688c9cc14f8ed0a0096812790b885b7f89 --- /dev/null +++ b/javascript/inpaint-anything.js @@ -0,0 +1,458 @@ +const inpaintAnything_waitForElement = async (parent, selector, exist) => { + return new Promise((resolve) => { + const observer = new MutationObserver(() => { + if (!!parent.querySelector(selector) != exist) { + return; + } + observer.disconnect(); + resolve(undefined); + }); + + observer.observe(parent, { + childList: true, + subtree: true, + }); + + if (!!parent.querySelector(selector) == exist) { + resolve(undefined); + } + }); +}; + +const inpaintAnything_waitForStyle = async (parent, selector, style) => { + return new Promise((resolve) => { + const observer = new MutationObserver(() => { + if (!parent.querySelector(selector) || !parent.querySelector(selector).style[style]) { + return; + } + observer.disconnect(); + resolve(undefined); + }); + + observer.observe(parent, { + childList: true, + subtree: true, + attributes: true, + attributeFilter: ["style"], + }); + + if (!!parent.querySelector(selector) && !!parent.querySelector(selector).style[style]) { + resolve(undefined); + } + }); +}; + +const inpaintAnything_timeout = (ms) => { + return new Promise(function (resolve, reject) { + setTimeout(() => reject("Timeout"), ms); + }); +}; + +async function inpaintAnything_clearSamMask() { + const waitForElementToBeInDocument = (parent, selector) => + Promise.race([inpaintAnything_waitForElement(parent, selector, true), inpaintAnything_timeout(1000)]); + + const elemId = "#ia_sam_image"; + + const targetElement = document.querySelector(elemId); + if (!targetElement) { + return; + } + await waitForElementToBeInDocument(targetElement, "button[aria-label='Clear']"); + + targetElement.style.transform = null; + targetElement.style.zIndex = null; + targetElement.style.overflow = "auto"; + + const samMaskClear = targetElement.querySelector("button[aria-label='Clear']"); + if (!samMaskClear) { + return; + } + const removeImageButton = targetElement.querySelector("button[aria-label='Remove Image']"); + if (!removeImageButton) { + return; + } + samMaskClear?.click(); + + if (typeof inpaintAnything_clearSamMask.clickRemoveImage === "undefined") { + inpaintAnything_clearSamMask.clickRemoveImage = () => { + targetElement.style.transform = null; + targetElement.style.zIndex = null; + }; + } else { + removeImageButton.removeEventListener("click", inpaintAnything_clearSamMask.clickRemoveImage); + } + removeImageButton.addEventListener("click", inpaintAnything_clearSamMask.clickRemoveImage); +} + +async function inpaintAnything_clearSelMask() { + const waitForElementToBeInDocument = (parent, selector) => + Promise.race([inpaintAnything_waitForElement(parent, selector, true), inpaintAnything_timeout(1000)]); + + const elemId = "#ia_sel_mask"; + + const targetElement = document.querySelector(elemId); + if (!targetElement) { + return; + } + await waitForElementToBeInDocument(targetElement, "button[aria-label='Clear']"); + + targetElement.style.transform = null; + targetElement.style.zIndex = null; + targetElement.style.overflow = "auto"; + + const selMaskClear = targetElement.querySelector("button[aria-label='Clear']"); + if (!selMaskClear) { + return; + } + const removeImageButton = targetElement.querySelector("button[aria-label='Remove Image']"); + if (!removeImageButton) { + return; + } + selMaskClear?.click(); + + if (typeof inpaintAnything_clearSelMask.clickRemoveImage === "undefined") { + inpaintAnything_clearSelMask.clickRemoveImage = () => { + targetElement.style.transform = null; + targetElement.style.zIndex = null; + }; + } else { + removeImageButton.removeEventListener("click", inpaintAnything_clearSelMask.clickRemoveImage); + } + removeImageButton.addEventListener("click", inpaintAnything_clearSelMask.clickRemoveImage); +} + +async function inpaintAnything_initSamSelMask() { + inpaintAnything_clearSamMask(); + inpaintAnything_clearSelMask(); +} + +var uiLoadedCallbacks = []; + +function gradioApp() { + const elems = document.getElementsByTagName("gradio-app"); + const elem = elems.length == 0 ? document : elems[0]; + + if (elem !== document) { + elem.getElementById = function (id) { + return document.getElementById(id); + }; + } + return elem.shadowRoot ? elem.shadowRoot : elem; +} + +function onUiLoaded(callback) { + uiLoadedCallbacks.push(callback); +} + +function executeCallbacks(queue) { + for (const callback of queue) { + try { + callback(); + } catch (e) { + console.error("error running callback", callback, ":", e); + } + } +} + +onUiLoaded(async () => { + const elementIDs = { + ia_sam_image: "#ia_sam_image", + ia_sel_mask: "#ia_sel_mask", + ia_out_image: "#ia_out_image", + ia_cleaner_out_image: "#ia_cleaner_out_image", + }; + + function setStyleHeight(elemId, height) { + const elem = gradioApp().querySelector(elemId); + if (elem) { + if (!elem.style.height) { + elem.style.height = height; + const observer = new MutationObserver(() => { + const divPreview = elem.querySelector(".preview"); + if (divPreview) { + divPreview.classList.remove("fixed-height"); + } + }); + observer.observe(elem, { + childList: true, + attributes: true, + attributeFilter: ["class"], + }); + } + } + } + + setStyleHeight(elementIDs.ia_out_image, "520px"); + setStyleHeight(elementIDs.ia_cleaner_out_image, "520px"); + + // Default config + const defaultHotkeysConfig = { + canvas_hotkey_reset: "KeyR", + canvas_hotkey_fullscreen: "KeyS", + }; + + const elemData = {}; + let activeElement; + + function applyZoomAndPan(elemId) { + const targetElement = gradioApp().querySelector(elemId); + + if (!targetElement) { + console.log("Element not found"); + return; + } + + targetElement.style.transformOrigin = "0 0"; + + elemData[elemId] = { + zoomLevel: 1, + panX: 0, + panY: 0, + }; + let fullScreenMode = false; + + // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements + function toggleOverlap(forced = "") { + // const zIndex1 = "0"; + const zIndex1 = null; + const zIndex2 = "998"; + + targetElement.style.zIndex = targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1; + + if (forced === "off") { + targetElement.style.zIndex = zIndex1; + } else if (forced === "on") { + targetElement.style.zIndex = zIndex2; + } + } + + /** + * This function fits the target element to the screen by calculating + * the required scale and offsets. It also updates the global variables + * zoomLevel, panX, and panY to reflect the new state. + */ + + function fitToElement() { + //Reset Zoom + targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + + // Get element and screen dimensions + const elementWidth = targetElement.offsetWidth; + const elementHeight = targetElement.offsetHeight; + const parentElement = targetElement.parentElement; + const screenWidth = parentElement.clientWidth; + const screenHeight = parentElement.clientHeight; + + // Get element's coordinates relative to the parent element + const elementRect = targetElement.getBoundingClientRect(); + const parentRect = parentElement.getBoundingClientRect(); + const elementX = elementRect.x - parentRect.x; + + // Calculate scale and offsets + const scaleX = screenWidth / elementWidth; + const scaleY = screenHeight / elementHeight; + const scale = Math.min(scaleX, scaleY); + + const transformOrigin = window.getComputedStyle(targetElement).transformOrigin; + const [originX, originY] = transformOrigin.split(" "); + const originXValue = parseFloat(originX); + const originYValue = parseFloat(originY); + + const offsetX = (screenWidth - elementWidth * scale) / 2 - originXValue * (1 - scale); + const offsetY = (screenHeight - elementHeight * scale) / 2.5 - originYValue * (1 - scale); + + // Apply scale and offsets to the element + targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; + + // Update global variables + elemData[elemId].zoomLevel = scale; + elemData[elemId].panX = offsetX; + elemData[elemId].panY = offsetY; + + fullScreenMode = false; + toggleOverlap("off"); + } + + // Reset the zoom level and pan position of the target element to their initial values + function resetZoom() { + elemData[elemId] = { + zoomLevel: 1, + panX: 0, + panY: 0, + }; + + // fixCanvas(); + targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`; + + // const canvas = gradioApp().querySelector(`${elemId} canvas[key="interface"]`); + + toggleOverlap("off"); + fullScreenMode = false; + + // if ( + // canvas && + // parseFloat(canvas.style.width) > 865 && + // parseFloat(targetElement.style.width) > 865 + // ) { + // fitToElement(); + // return; + // } + + // targetElement.style.width = ""; + // if (canvas) { + // targetElement.style.height = canvas.style.height; + // } + targetElement.style.width = null; + targetElement.style.height = 480; + } + + /** + * This function fits the target element to the screen by calculating + * the required scale and offsets. It also updates the global variables + * zoomLevel, panX, and panY to reflect the new state. + */ + + // Fullscreen mode + function fitToScreen() { + const canvas = gradioApp().querySelector(`${elemId} canvas[key="interface"]`); + const img = gradioApp().querySelector(`${elemId} img`); + + if (!canvas && !img) return; + + // if (canvas.offsetWidth > 862) { + // targetElement.style.width = canvas.offsetWidth + "px"; + // } + + if (fullScreenMode) { + resetZoom(); + fullScreenMode = false; + return; + } + + //Reset Zoom + targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + + // Get scrollbar width to right-align the image + const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth; + + // Get element and screen dimensions + const elementWidth = targetElement.offsetWidth; + const elementHeight = targetElement.offsetHeight; + const screenWidth = window.innerWidth - scrollbarWidth; + const screenHeight = window.innerHeight; + + // Get element's coordinates relative to the page + const elementRect = targetElement.getBoundingClientRect(); + const elementY = elementRect.y; + const elementX = elementRect.x; + + // Calculate scale and offsets + const scaleX = screenWidth / elementWidth; + const scaleY = screenHeight / elementHeight; + const scale = Math.min(scaleX, scaleY); + + // Get the current transformOrigin + const computedStyle = window.getComputedStyle(targetElement); + const transformOrigin = computedStyle.transformOrigin; + const [originX, originY] = transformOrigin.split(" "); + const originXValue = parseFloat(originX); + const originYValue = parseFloat(originY); + + // Calculate offsets with respect to the transformOrigin + const offsetX = (screenWidth - elementWidth * scale) / 2 - elementX - originXValue * (1 - scale); + const offsetY = (screenHeight - elementHeight * scale) / 2 - elementY - originYValue * (1 - scale); + + // Apply scale and offsets to the element + targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; + + // Update global variables + elemData[elemId].zoomLevel = scale; + elemData[elemId].panX = offsetX; + elemData[elemId].panY = offsetY; + + fullScreenMode = true; + toggleOverlap("on"); + } + + // Reset zoom when uploading a new image + const fileInput = gradioApp().querySelector(`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`); + if (fileInput) { + fileInput.addEventListener("click", resetZoom); + } + + // Handle keydown events + function handleKeyDown(event) { + // Disable key locks to make pasting from the buffer work correctly + if ( + (event.ctrlKey && event.code === "KeyV") || + (event.ctrlKey && event.code === "KeyC") || + event.code === "F5" + ) { + return; + } + + // before activating shortcut, ensure user is not actively typing in an input field + if (event.target.nodeName === "TEXTAREA" || event.target.nodeName === "INPUT") { + return; + } + + const hotkeyActions = { + [defaultHotkeysConfig.canvas_hotkey_reset]: resetZoom, + [defaultHotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen, + }; + + const action = hotkeyActions[event.code]; + if (action) { + event.preventDefault(); + action(event); + } + } + + // Handle events only inside the targetElement + let isKeyDownHandlerAttached = false; + + function handleMouseMove() { + if (!isKeyDownHandlerAttached) { + document.addEventListener("keydown", handleKeyDown); + isKeyDownHandlerAttached = true; + + activeElement = elemId; + } + } + + function handleMouseLeave() { + if (isKeyDownHandlerAttached) { + document.removeEventListener("keydown", handleKeyDown); + isKeyDownHandlerAttached = false; + + activeElement = null; + } + } + + // Add mouse event handlers + targetElement.addEventListener("mousemove", handleMouseMove); + targetElement.addEventListener("mouseleave", handleMouseLeave); + } + + applyZoomAndPan(elementIDs.ia_sam_image); + applyZoomAndPan(elementIDs.ia_sel_mask); + // applyZoomAndPan(elementIDs.ia_out_image); + // applyZoomAndPan(elementIDs.ia_cleaner_out_image); +}); + +var executedOnLoaded = false; + +document.addEventListener("DOMContentLoaded", function () { + var mutationObserver = new MutationObserver(function () { + if ( + !executedOnLoaded && + gradioApp().querySelector("#ia_sam_image") && + gradioApp().querySelector("#ia_sel_mask") + ) { + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + }); + mutationObserver.observe(gradioApp(), { childList: true, subtree: true }); +}); diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8145aeff12798ce8e10f4841715f823bd4fcbeef --- /dev/null +++ b/lama_cleaner/__init__.py @@ -0,0 +1,19 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import warnings # noqa: E402 + +warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") +warnings.filterwarnings("ignore", category=UserWarning, module="lama_cleaner") + +from lama_cleaner.parse_args import parse_args # noqa: E402 + + +def entry_point(): + args = parse_args() + # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers + # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 + from lama_cleaner.server import main + + main(args) diff --git a/lama_cleaner/benchmark.py b/lama_cleaner/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..633b2f9349bebbf72b2390d9c8fc3275d86eb3ab --- /dev/null +++ b/lama_cleaner/benchmark.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time + +import numpy as np +import nvidia_smi +import psutil +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config, HDStrategy, SDSampler + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +NUM_THREADS = str(4) + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + + +def run_model(model, size): + # RGB + image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8) + mask = np.random.randint(0, 255, size).astype(np.uint8) + + config = Config( + ldm_steps=2, + hd_strategy=HDStrategy.ORIGINAL, + hd_strategy_crop_margin=128, + hd_strategy_crop_trigger_size=128, + hd_strategy_resize_limit=128, + prompt="a fox is sitting on a bench", + sd_steps=5, + sd_sampler=SDSampler.ddim + ) + model(image, mask, config) + + +def benchmark(model, times: int, empty_cache: bool): + sizes = [(512, 512)] + + nvidia_smi.nvmlInit() + device_id = 0 + handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id) + + def format(metrics): + return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}" + + process = psutil.Process(os.getpid()) + # 每个 size 给出显存和内存占用的指标 + for size in sizes: + torch.cuda.empty_cache() + time_metrics = [] + cpu_metrics = [] + memory_metrics = [] + gpu_memory_metrics = [] + for _ in range(times): + start = time.time() + run_model(model, size) + torch.cuda.synchronize() + + # cpu_metrics.append(process.cpu_percent()) + time_metrics.append((time.time() - start) * 1000) + memory_metrics.append(process.memory_info().rss / 1024 / 1024) + gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024) + + print(f"size: {size}".center(80, "-")) + # print(f"cpu: {format(cpu_metrics)}") + print(f"latency: {format(time_metrics)}ms") + print(f"memory: {format(memory_metrics)} MB") + print(f"gpu memory: {format(gpu_memory_metrics)} MB") + + nvidia_smi.nvmlShutdown() + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--name") + parser.add_argument("--device", default="cuda", type=str) + parser.add_argument("--times", default=10, type=int) + parser.add_argument("--empty-cache", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args_parser() + device = torch.device(args.device) + model = ModelManager( + name=args.name, + device=device, + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=True, + hf_access_token="123" + ) + benchmark(model, args.times, args.empty_cache) diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf1d1642a2b7db7672008373b058326b40be3a6 --- /dev/null +++ b/lama_cleaner/const.py @@ -0,0 +1,173 @@ +import json +import os +from enum import Enum +from pydantic import BaseModel + + +MPS_SUPPORT_MODELS = [ + "instruct_pix2pix", + "sd1.5", + "anything4", + "realisticVision1.4", + "sd2", + "paint_by_example", + "controlnet", +] + +DEFAULT_MODEL = "lama" +AVAILABLE_MODELS = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "sd1.5", + "anything4", + "realisticVision1.4", + "cv2", + "manga", + "sd2", + "paint_by_example", + "instruct_pix2pix", +] +SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] + +AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] +DEFAULT_DEVICE = "cuda" + +NO_HALF_HELP = """ +Using full precision model. +If your generate result is always black or green, use this argument. (sd/paint_by_exmaple) +""" + +CPU_OFFLOAD_HELP = """ +Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example) +""" + +DISABLE_NSFW_HELP = """ +Disable NSFW checker. (sd/paint_by_example) +""" + +SD_CPU_TEXTENCODER_HELP = """ +Run Stable Diffusion text encoder model on CPU to save GPU memory. +""" + +SD_CONTROLNET_HELP = """ +Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui. +""" +DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny" +SD_CONTROLNET_CHOICES = [ + "control_v11p_sd15_canny", + "control_v11p_sd15_openpose", + "control_v11p_sd15_inpaint", + "control_v11f1p_sd15_depth" +] + +SD_LOCAL_MODEL_HELP = """ +Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path. +""" + +LOCAL_FILES_ONLY_HELP = """ +Use local files only, not connect to Hugging Face server. (sd/paint_by_example) +""" + +ENABLE_XFORMERS_HELP = """ +Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example) +""" + +DEFAULT_MODEL_DIR = os.getenv( + "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache") +) +MODEL_DIR_HELP = """ +Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache +""" + +OUTPUT_DIR_HELP = """ +Result images will be saved to output directory automatically without confirmation. +""" + +INPUT_HELP = """ +If input is image, it will be loaded by default. +If input is directory, you can browse and select image in file manager. +""" + +GUI_HELP = """ +Launch Lama Cleaner as desktop app +""" + +NO_GUI_AUTO_CLOSE_HELP = """ +Prevent backend auto close after the GUI window closed. +""" + +QUALITY_HELP = """ +Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size. +""" + + +class RealESRGANModelName(str, Enum): + realesr_general_x4v3 = "realesr-general-x4v3" + RealESRGAN_x4plus = "RealESRGAN_x4plus" + RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" + + +RealESRGANModelNameList = [e.value for e in RealESRGANModelName] + +INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." +INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." +AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h"] +AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"] +REMOVE_BG_HELP = "Enable remove background. Always run on CPU" +ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU" +REALESRGAN_HELP = "Enable realesrgan super resolution" +REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] +GFPGAN_HELP = ( + "Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan" +) +GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] +RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan" +RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] +GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image" + + +class Config(BaseModel): + host: str = "127.0.0.1" + port: int = 8080 + model: str = DEFAULT_MODEL + sd_local_model_path: str = None + sd_controlnet: bool = False + sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD + device: str = DEFAULT_DEVICE + gui: bool = False + no_gui_auto_close: bool = False + no_half: bool = False + cpu_offload: bool = False + disable_nsfw: bool = False + sd_cpu_textencoder: bool = False + enable_xformers: bool = False + local_files_only: bool = False + model_dir: str = DEFAULT_MODEL_DIR + input: str = None + output_dir: str = None + # plugins + enable_interactive_seg: bool = False + interactive_seg_model: str = "vit_l" + interactive_seg_device: str = "cpu" + enable_remove_bg: bool = False + enable_anime_seg: bool = False + enable_realesrgan: bool = False + realesrgan_device: str = "cpu" + realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value + realesrgan_no_half: bool = False + enable_gfpgan: bool = False + gfpgan_device: str = "cpu" + enable_restoreformer: bool = False + restoreformer_device: str = "cpu" + enable_gif: bool = False + + +def load_config(installer_config: str): + if os.path.exists(installer_config): + with open(installer_config, "r", encoding="utf-8") as f: + return Config(**json.load(f)) + else: + return Config() diff --git a/lama_cleaner/file_manager/__init__.py b/lama_cleaner/file_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87ddab7e052afcb2ac168e3f8b1f114e406b2938 --- /dev/null +++ b/lama_cleaner/file_manager/__init__.py @@ -0,0 +1 @@ +from .file_manager import FileManager diff --git a/lama_cleaner/file_manager/file_manager.py b/lama_cleaner/file_manager/file_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1777792dbdfafd4d39ba181be51df2795f768a --- /dev/null +++ b/lama_cleaner/file_manager/file_manager.py @@ -0,0 +1,265 @@ +# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py +import os +from datetime import datetime + +import cv2 +import time +from io import BytesIO +from pathlib import Path +import numpy as np +# from watchdog.events import FileSystemEventHandler +# from watchdog.observers import Observer + +from PIL import Image, ImageOps, PngImagePlugin +from loguru import logger + +LARGE_ENOUGH_NUMBER = 100 +PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) +from .storage_backends import FilesystemStorageBackend +from .utils import aspect_to_string, generate_filename, glob_img + + +class FileManager: + def __init__(self, app=None): + self.app = app + self._default_root_directory = "media" + self._default_thumbnail_directory = "media" + self._default_root_url = "/" + self._default_thumbnail_root_url = "/" + self._default_format = "JPEG" + self.output_dir: Path = None + + if app is not None: + self.init_app(app) + + self.image_dir_filenames = [] + self.output_dir_filenames = [] + + self.image_dir_observer = None + self.output_dir_observer = None + + self.modified_time = { + "image": datetime.utcnow(), + "output": datetime.utcnow(), + } + + # def start(self): + # self.image_dir_filenames = self._media_names(self.root_directory) + # self.output_dir_filenames = self._media_names(self.output_dir) + # + # logger.info(f"Start watching image directory: {self.root_directory}") + # self.image_dir_observer = Observer() + # self.image_dir_observer.schedule(self, self.root_directory, recursive=False) + # self.image_dir_observer.start() + # + # logger.info(f"Start watching output directory: {self.output_dir}") + # self.output_dir_observer = Observer() + # self.output_dir_observer.schedule(self, self.output_dir, recursive=False) + # self.output_dir_observer.start() + + def on_modified(self, event): + if not os.path.isdir(event.src_path): + return + if event.src_path == str(self.root_directory): + logger.info(f"Image directory {event.src_path} modified") + self.image_dir_filenames = self._media_names(self.root_directory) + self.modified_time["image"] = datetime.utcnow() + elif event.src_path == str(self.output_dir): + logger.info(f"Output directory {event.src_path} modified") + self.output_dir_filenames = self._media_names(self.output_dir) + self.modified_time["output"] = datetime.utcnow() + + def init_app(self, app): + if self.app is None: + self.app = app + app.thumbnail_instance = self + + if not hasattr(app, "extensions"): + app.extensions = {} + + if "thumbnail" in app.extensions: + raise RuntimeError("Flask-thumbnail extension already initialized") + + app.extensions["thumbnail"] = self + + app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory) + app.config.setdefault( + "THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory + ) + app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url) + app.config.setdefault( + "THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url + ) + app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format) + + @property + def root_directory(self): + path = self.app.config["THUMBNAIL_MEDIA_ROOT"] + + if os.path.isabs(path): + return path + else: + return os.path.join(self.app.root_path, path) + + @property + def thumbnail_directory(self): + path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] + + if os.path.isabs(path): + return path + else: + return os.path.join(self.app.root_path, path) + + @property + def root_url(self): + return self.app.config["THUMBNAIL_MEDIA_URL"] + + @property + def media_names(self): + # return self.image_dir_filenames + return self._media_names(self.root_directory) + + @property + def output_media_names(self): + return self._media_names(self.output_dir) + # return self.output_dir_filenames + + @staticmethod + def _media_names(directory: Path): + names = sorted([it.name for it in glob_img(directory)]) + res = [] + for name in names: + path = os.path.join(directory, name) + img = Image.open(path) + res.append( + { + "name": name, + "height": img.height, + "width": img.width, + "ctime": os.path.getctime(path), + "mtime": os.path.getmtime(path), + } + ) + return res + + @property + def thumbnail_url(self): + return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"] + + def get_thumbnail( + self, directory: Path, original_filename: str, width, height, **options + ): + storage = FilesystemStorageBackend(self.app) + crop = options.get("crop", "fit") + background = options.get("background") + quality = options.get("quality", 90) + + original_path, original_filename = os.path.split(original_filename) + original_filepath = os.path.join(directory, original_path, original_filename) + image = Image.open(BytesIO(storage.read(original_filepath))) + + # keep ratio resize + if width is not None: + height = int(image.height * width / image.width) + else: + width = int(image.width * height / image.height) + + thumbnail_size = (width, height) + + thumbnail_filename = generate_filename( + original_filename, + aspect_to_string(thumbnail_size), + crop, + background, + quality, + ) + + thumbnail_filepath = os.path.join( + self.thumbnail_directory, original_path, thumbnail_filename + ) + thumbnail_url = os.path.join( + self.thumbnail_url, original_path, thumbnail_filename + ) + + if storage.exists(thumbnail_filepath): + return thumbnail_url, (width, height) + + try: + image.load() + except (IOError, OSError): + self.app.logger.warning("Thumbnail not load image: %s", original_filepath) + return thumbnail_url, (width, height) + + # get original image format + options["format"] = options.get("format", image.format) + + image = self._create_thumbnail( + image, thumbnail_size, crop, background=background + ) + + raw_data = self.get_raw_data(image, **options) + storage.save(thumbnail_filepath, raw_data) + + return thumbnail_url, (width, height) + + def get_raw_data(self, image, **options): + data = { + "format": self._get_format(image, **options), + "quality": options.get("quality", 90), + } + + _file = BytesIO() + image.save(_file, **data) + return _file.getvalue() + + @staticmethod + def colormode(image, colormode="RGB"): + if colormode == "RGB" or colormode == "RGBA": + if image.mode == "RGBA": + return image + if image.mode == "LA": + return image.convert("RGBA") + return image.convert(colormode) + + if colormode == "GRAY": + return image.convert("L") + + return image.convert(colormode) + + @staticmethod + def background(original_image, color=0xFF): + size = (max(original_image.size),) * 2 + image = Image.new("L", size, color) + image.paste( + original_image, + tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))), + ) + + return image + + def _get_format(self, image, **options): + if options.get("format"): + return options.get("format") + if image.format: + return image.format + + return self.app.config["THUMBNAIL_DEFAULT_FORMAT"] + + def _create_thumbnail(self, image, size, crop="fit", background=None): + try: + resample = Image.Resampling.LANCZOS + except AttributeError: # pylint: disable=raise-missing-from + resample = Image.ANTIALIAS + + if crop == "fit": + image = ImageOps.fit(image, size, resample) + else: + image = image.copy() + image.thumbnail(size, resample=resample) + + if background is not None: + image = self.background(image) + + image = self.colormode(image) + + return image diff --git a/lama_cleaner/file_manager/storage_backends.py b/lama_cleaner/file_manager/storage_backends.py new file mode 100644 index 0000000000000000000000000000000000000000..be5b60c9e05f8add5d979081551c7d8380b9dcb7 --- /dev/null +++ b/lama_cleaner/file_manager/storage_backends.py @@ -0,0 +1,46 @@ +# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py +import errno +import os +from abc import ABC, abstractmethod + + +class BaseStorageBackend(ABC): + def __init__(self, app=None): + self.app = app + + @abstractmethod + def read(self, filepath, mode="rb", **kwargs): + raise NotImplementedError + + @abstractmethod + def exists(self, filepath): + raise NotImplementedError + + @abstractmethod + def save(self, filepath, data): + raise NotImplementedError + + +class FilesystemStorageBackend(BaseStorageBackend): + def read(self, filepath, mode="rb", **kwargs): + with open(filepath, mode) as f: # pylint: disable=unspecified-encoding + return f.read() + + def exists(self, filepath): + return os.path.exists(filepath) + + def save(self, filepath, data): + directory = os.path.dirname(filepath) + + if not os.path.exists(directory): + try: + os.makedirs(directory) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + if not os.path.isdir(directory): + raise IOError("{} is not a directory".format(directory)) + + with open(filepath, "wb") as f: + f.write(data) diff --git a/lama_cleaner/file_manager/utils.py b/lama_cleaner/file_manager/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4f4e3a3ce45b01b0c19aea70e526784ee7fa99 --- /dev/null +++ b/lama_cleaner/file_manager/utils.py @@ -0,0 +1,67 @@ +# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py +import importlib +import os +from pathlib import Path + +from typing import Union + + +def generate_filename(original_filename, *options): + name, ext = os.path.splitext(original_filename) + for v in options: + if v: + name += "_%s" % v + name += ext + + return name + + +def parse_size(size): + if isinstance(size, int): + # If the size parameter is a single number, assume square aspect. + return [size, size] + + if isinstance(size, (tuple, list)): + if len(size) == 1: + # If single value tuple/list is provided, exand it to two elements + return size + type(size)(size) + return size + + try: + thumbnail_size = [int(x) for x in size.lower().split("x", 1)] + except ValueError: + raise ValueError( # pylint: disable=raise-missing-from + "Bad thumbnail size format. Valid format is INTxINT." + ) + + if len(thumbnail_size) == 1: + # If the size parameter only contains a single integer, assume square aspect. + thumbnail_size.append(thumbnail_size[0]) + + return thumbnail_size + + +def aspect_to_string(size): + if isinstance(size, str): + return size + + return "x".join(map(str, size)) + + +IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'} + + +def glob_img(p: Union[Path, str], recursive: bool = False): + p = Path(p) + if p.is_file() and p.suffix in IMG_SUFFIX: + yield p + else: + if recursive: + files = Path(p).glob("**/*.*") + else: + files = Path(p).glob("*.*") + + for it in files: + if it.suffix not in IMG_SUFFIX: + continue + yield it diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5c08aafc6c7a9d12ab9e55c8d210a1d43e32d4 --- /dev/null +++ b/lama_cleaner/helper.py @@ -0,0 +1,292 @@ +import io +import os +import sys +from typing import List, Optional + +from urllib.parse import urlparse +import cv2 +from PIL import Image, ImageOps, PngImagePlugin +import numpy as np +import torch +from lama_cleaner.const import MPS_SUPPORT_MODELS +from loguru import logger +from torch.hub import download_url_to_file, get_dir +import hashlib + + +def md5sum(filename): + md5 = hashlib.md5() + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(128 * md5.block_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def switch_mps_device(model_name, device): + if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps": + logger.info(f"{model_name} not support mps, switch to cpu") + return torch.device("cpu") + return device + + +def get_cache_path_by_url(url): + parts = urlparse(url) + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + if not os.path.isdir(model_dir): + os.makedirs(model_dir) + filename = os.path.basename(parts.path) + cached_file = os.path.join(model_dir, filename) + return cached_file + + +def download_model(url, model_md5: str = None): + cached_file = get_cache_path_by_url(url) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + download_url_to_file(url, cached_file, hash_prefix, progress=True) + if model_md5: + _md5 = md5sum(cached_file) + if model_md5 == _md5: + logger.info(f"Download model success, md5: {_md5}") + else: + try: + os.remove(cached_file) + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart lama-cleaner." + f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" + ) + except: + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart lama-cleaner." + ) + exit(-1) + + return cached_file + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def handle_error(model_path, model_md5, e): + _md5 = md5sum(model_path) + if _md5 != model_md5: + try: + os.remove(model_path) + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart lama-cleaner." + f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" + ) + except: + logger.error( + f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart lama-cleaner." + ) + else: + logger.error( + f"Failed to load model {model_path}," + f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}" + ) + exit(-1) + + +def load_jit_model(url_or_path, device, model_md5: str): + if os.path.exists(url_or_path): + model_path = url_or_path + else: + model_path = download_model(url_or_path, model_md5) + + logger.info(f"Loading model from: {model_path}") + try: + model = torch.jit.load(model_path, map_location="cpu").to(device) + except Exception as e: + handle_error(model_path, model_md5, e) + model.eval() + return model + + +def load_model(model: torch.nn.Module, url_or_path, device, model_md5): + if os.path.exists(url_or_path): + model_path = url_or_path + else: + model_path = download_model(url_or_path, model_md5) + + try: + logger.info(f"Loading model from: {model_path}") + state_dict = torch.load(model_path, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + model.to(device) + except Exception as e: + handle_error(model_path, model_md5, e) + model.eval() + return model + + +def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: + data = cv2.imencode( + f".{ext}", + image_numpy, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + )[1] + image_bytes = data.tobytes() + return image_bytes + + +def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif_infos={}) -> bytes: + with io.BytesIO() as output: + kwargs = {k: v for k, v in exif_infos.items() if v is not None} + if ext == "png" and "parameters" in kwargs: + pnginfo_data = PngImagePlugin.PngInfo() + pnginfo_data.add_text("parameters", kwargs["parameters"]) + kwargs["pnginfo"] = pnginfo_data + + pil_img.save( + output, + format=ext, + quality=quality, + **kwargs, + ) + image_bytes = output.getvalue() + return image_bytes + + +def load_img(img_bytes, gray: bool = False, return_exif: bool = False): + alpha_channel = None + image = Image.open(io.BytesIO(img_bytes)) + + if return_exif: + info = image.info or {} + exif_infos = {"exif": image.getexif(), "parameters": info.get("parameters")} + + try: + image = ImageOps.exif_transpose(image) + except: + pass + + if gray: + image = image.convert("L") + np_img = np.array(image) + else: + if image.mode == "RGBA": + np_img = np.array(image) + alpha_channel = np_img[:, :, -1] + np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) + else: + image = image.convert("RGB") + np_img = np.array(image) + + if return_exif: + return np_img, alpha_channel, exif_infos + return np_img, alpha_channel + + +def norm_img(np_img): + if len(np_img.shape) == 2: + np_img = np_img[:, :, np.newaxis] + np_img = np.transpose(np_img, (2, 0, 1)) + np_img = np_img.astype("float32") / 255 + return np_img + + +def resize_max_size( + np_img, size_limit: int, interpolation=cv2.INTER_CUBIC +) -> np.ndarray: + # Resize image's longer size to size_limit if longer size larger than size_limit + h, w = np_img.shape[:2] + if max(h, w) > size_limit: + ratio = size_limit / max(h, w) + new_w = int(w * ratio + 0.5) + new_h = int(h * ratio + 0.5) + return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation) + else: + return np_img + + +def pad_img_to_modulo( + img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None +): + """ + + Args: + img: [H, W, C] + mod: + square: 是否为正方形 + min_size: + + Returns: + + """ + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + height, width = img.shape[:2] + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + + if min_size is not None: + assert min_size % mod == 0 + out_width = max(min_size, out_width) + out_height = max(min_size, out_height) + + if square: + max_size = max(out_height, out_width) + out_height = max_size + out_width = max_size + + return np.pad( + img, + ((0, out_height - height), (0, out_width - width), (0, 0)), + mode="symmetric", + ) + + +def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]: + """ + Args: + mask: (h, w, 1) 0~255 + + Returns: + + """ + height, width = mask.shape[:2] + _, thresh = cv2.threshold(mask, 127, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + boxes = [] + for cnt in contours: + x, y, w, h = cv2.boundingRect(cnt) + box = np.array([x, y, x + w, y + h]).astype(int) + + box[::2] = np.clip(box[::2], 0, width) + box[1::2] = np.clip(box[1::2], 0, height) + boxes.append(box) + + return boxes + + +def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]: + """ + Args: + mask: (h, w) 0~255 + + Returns: + + """ + _, thresh = cv2.threshold(mask, 127, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + max_area = 0 + max_index = -1 + for i, cnt in enumerate(contours): + area = cv2.contourArea(cnt) + if area > max_area: + max_area = area + max_index = i + + if max_index != -1: + new_mask = np.zeros_like(mask) + return cv2.drawContours(new_mask, contours, max_index, 255, -1) + else: + return mask diff --git a/lama_cleaner/installer.py b/lama_cleaner/installer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ae5e750fbb91056030adb38184a80958c4fdf6 --- /dev/null +++ b/lama_cleaner/installer.py @@ -0,0 +1,12 @@ +import subprocess +import sys + + +def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +def install_plugins_package(): + install("rembg") + install("realesrgan") + install("gfpgan") diff --git a/lama_cleaner/model/__init__.py b/lama_cleaner/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..055d3f3be4075df4fc42b0fed3159ec5210711ae --- /dev/null +++ b/lama_cleaner/model/base.py @@ -0,0 +1,298 @@ +import abc +from typing import Optional + +import cv2 +import torch +import numpy as np +from loguru import logger + +from lama_cleaner.helper import ( + boxes_from_mask, + resize_max_size, + pad_img_to_modulo, + switch_mps_device, +) +from lama_cleaner.schema import Config, HDStrategy + + +class InpaintModel: + name = "base" + min_size: Optional[int] = None + pad_mod = 8 + pad_to_square = False + + def __init__(self, device, **kwargs): + """ + + Args: + device: + """ + device = switch_mps_device(self.name, device) + self.device = device + self.init_model(device, **kwargs) + + @abc.abstractmethod + def init_model(self, device, **kwargs): + ... + + @staticmethod + @abc.abstractmethod + def is_downloaded() -> bool: + ... + + @abc.abstractmethod + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W, 1] 255 为 masks 区域 + return: BGR IMAGE + """ + ... + + def _pad_forward(self, image, mask, config: Config): + origin_height, origin_width = image.shape[:2] + pad_image = pad_img_to_modulo( + image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + pad_mask = pad_img_to_modulo( + mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + + logger.info(f"final forward pad size: {pad_image.shape}") + + result = self.forward(pad_image, pad_mask, config) + result = result[0:origin_height, 0:origin_width, :] + + result, image, mask = self.forward_post_process(result, image, mask, config) + + mask = mask[:, :, np.newaxis] + result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255)) + return result + + def forward_post_process(self, result, image, mask, config): + return result, image, mask + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + inpaint_result = None + logger.info(f"hd_strategy: {config.hd_strategy}") + if config.hd_strategy == HDStrategy.CROP: + if max(image.shape) > config.hd_strategy_crop_trigger_size: + logger.info(f"Run crop strategy") + boxes = boxes_from_mask(mask) + crop_result = [] + for box in boxes: + crop_image, crop_box = self._run_box(image, mask, box, config) + crop_result.append((crop_image, crop_box)) + + inpaint_result = image[:, :, ::-1] + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + elif config.hd_strategy == HDStrategy.RESIZE: + if max(image.shape) > config.hd_strategy_resize_limit: + origin_size = image.shape[:2] + downsize_image = resize_max_size( + image, size_limit=config.hd_strategy_resize_limit + ) + downsize_mask = resize_max_size( + mask, size_limit=config.hd_strategy_resize_limit + ) + + logger.info( + f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}" + ) + inpaint_result = self._pad_forward( + downsize_image, downsize_mask, config + ) + + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + original_pixel_indices = mask < 127 + inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + original_pixel_indices + ] + + if inpaint_result is None: + inpaint_result = self._pad_forward(image, mask, config) + + return inpaint_result + + def _crop_box(self, image, mask, box, config: Config): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE, (l, r, r, b) + """ + box_h = box[3] - box[1] + box_w = box[2] - box[0] + cx = (box[0] + box[2]) // 2 + cy = (box[1] + box[3]) // 2 + img_h, img_w = image.shape[:2] + + w = box_w + config.hd_strategy_crop_margin * 2 + h = box_h + config.hd_strategy_crop_margin * 2 + + _l = cx - w // 2 + _r = cx + w // 2 + _t = cy - h // 2 + _b = cy + h // 2 + + l = max(_l, 0) + r = min(_r, img_w) + t = max(_t, 0) + b = min(_b, img_h) + + # try to get more context when crop around image edge + if _l < 0: + r += abs(_l) + if _r > img_w: + l -= _r - img_w + if _t < 0: + b += abs(_t) + if _b > img_h: + t -= _b - img_h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + + logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}") + + return crop_img, crop_mask, [l, t, r, b] + + def _calculate_cdf(self, histogram): + cdf = histogram.cumsum() + normalized_cdf = cdf / float(cdf.max()) + return normalized_cdf + + def _calculate_lookup(self, source_cdf, reference_cdf): + lookup_table = np.zeros(256) + lookup_val = 0 + for source_index, source_val in enumerate(source_cdf): + for reference_index, reference_val in enumerate(reference_cdf): + if reference_val >= source_val: + lookup_val = reference_index + break + lookup_table[source_index] = lookup_val + return lookup_table + + def _match_histograms(self, source, reference, mask): + transformed_channels = [] + for channel in range(source.shape[-1]): + source_channel = source[:, :, channel] + reference_channel = reference[:, :, channel] + + # only calculate histograms for non-masked parts + source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256]) + reference_histogram, _ = np.histogram( + reference_channel[mask == 0], 256, [0, 256] + ) + + source_cdf = self._calculate_cdf(source_histogram) + reference_cdf = self._calculate_cdf(reference_histogram) + + lookup = self._calculate_lookup(source_cdf, reference_cdf) + + transformed_channels.append(cv2.LUT(source_channel, lookup)) + + result = cv2.merge(transformed_channels) + result = cv2.convertScaleAbs(result) + + return result + + def _apply_cropper(self, image, mask, config: Config): + img_h, img_w = image.shape[:2] + l, t, w, h = ( + config.croper_x, + config.croper_y, + config.croper_width, + config.croper_height, + ) + r = l + w + b = t + h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + return crop_img, crop_mask, (l, t, r, b) + + def _run_box(self, image, mask, box, config: Config): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE + """ + crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config) + + return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b] + + +class DiffusionInpaintModel(InpaintModel): + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + # boxes = boxes_from_mask(mask) + if config.use_croper: + crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) + crop_image = self._scaled_pad_forward(crop_img, crop_mask, config) + inpaint_result = image[:, :, ::-1] + inpaint_result[t:b, l:r, :] = crop_image + else: + inpaint_result = self._scaled_pad_forward(image, mask, config) + + return inpaint_result + + def _scaled_pad_forward(self, image, mask, config: Config): + longer_side_length = int(config.sd_scale * max(image.shape[:2])) + origin_size = image.shape[:2] + downsize_image = resize_max_size(image, size_limit=longer_side_length) + downsize_mask = resize_max_size(mask, size_limit=longer_side_length) + if config.sd_scale != 1: + logger.info( + f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}" + ) + inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + original_pixel_indices = mask < 127 + inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + original_pixel_indices + ] + return inpaint_result diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2beb88fcea9866409c9aa22965ea6671ee224a --- /dev/null +++ b/lama_cleaner/model/controlnet.py @@ -0,0 +1,289 @@ +import gc + +import PIL.Image +import cv2 +import numpy as np +import torch +from diffusers import ControlNetModel +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import torch_gc, get_scheduler +from lama_cleaner.schema import Config + + +class CPUTextEncoderWrapper: + def __init__(self, text_encoder, torch_dtype): + self.config = text_encoder.config + self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) + self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) + self.torch_dtype = torch_dtype + del text_encoder + torch_gc() + + def __call__(self, x, **kwargs): + input_device = x.device + return [ + self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] + .to(input_device) + .to(self.torch_dtype) + ] + + @property + def dtype(self): + return self.torch_dtype + + +NAMES_MAP = { + "sd1.5": "runwayml/stable-diffusion-inpainting", + "anything4": "Sanster/anything-4.0-inpainting", + "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting", +} + +NATIVE_NAMES_MAP = { + "sd1.5": "runwayml/stable-diffusion-v1-5", + "anything4": "andite/anything-v4.0", + "realisticVision1.4": "SG161222/Realistic_Vision_V1.4", +} + + +def make_inpaint_condition(image, image_mask): + """ + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + """ + image = image.astype(np.float32) / 255.0 + image[image_mask[:, :, -1] > 128] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image + + +def load_from_local_model( + local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint +): + from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + download_from_original_stable_diffusion_ckpt, + ) + + logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline") + + try: + pipe = download_from_original_stable_diffusion_ckpt( + local_model_path, + num_in_channels=4 if is_native_control_inpaint else 9, + from_safetensors=local_model_path.endswith("safetensors"), + device="cpu", + load_safety_checker=False, + ) + except Exception as e: + err_msg = str(e) + logger.exception(e) + if is_native_control_inpaint and "[320, 9, 3, 3]" in err_msg: + logger.error( + "control_v11p_sd15_inpaint method requires normal SD model, not inpainting SD model" + ) + if not is_native_control_inpaint and "[320, 4, 3, 3]" in err_msg: + logger.error( + f"{controlnet.config['_name_or_path']} method requires inpainting SD model, " + f"you can convert any SD model to inpainting model in AUTO1111: \n" + f"https://www.reddit.com/r/StableDiffusion/comments/zyi24j/how_to_turn_any_model_into_an_inpainting_model/" + ) + exit(-1) + + inpaint_pipe = pipe_class( + vae=pipe.vae, + text_encoder=pipe.text_encoder, + tokenizer=pipe.tokenizer, + unet=pipe.unet, + controlnet=controlnet, + scheduler=pipe.scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + + del pipe + gc.collect() + return inpaint_pipe.to(torch_dtype=torch_dtype) + + +class ControlNet(DiffusionInpaintModel): + name = "controlnet" + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + fp16 = not kwargs.get("no_half", False) + + model_kwargs = { + "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) + } + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + sd_controlnet_method = kwargs["sd_controlnet_method"] + self.sd_controlnet_method = sd_controlnet_method + + if sd_controlnet_method == "control_v11p_sd15_inpaint": + from diffusers import StableDiffusionControlNetPipeline as PipeClass + + self.is_native_control_inpaint = True + else: + from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass + + self.is_native_control_inpaint = False + + if self.is_native_control_inpaint: + model_id = NATIVE_NAMES_MAP[kwargs["name"]] + else: + model_id = NAMES_MAP[kwargs["name"]] + + controlnet = ControlNetModel.from_pretrained( + f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype + ) + self.is_local_sd_model = False + if kwargs.get("sd_local_model_path", None): + self.is_local_sd_model = True + self.model = load_from_local_model( + kwargs["sd_local_model_path"], + torch_dtype=torch_dtype, + controlnet=controlnet, + pipe_class=PipeClass, + is_native_control_inpaint=self.is_native_control_inpaint, + ) + else: + self.model = PipeClass.from_pretrained( + model_id, + controlnet=controlnet, + revision="fp16" if use_gpu and fp16 else "main", + torch_dtype=torch_dtype, + **model_kwargs, + ) + + # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing + self.model.enable_attention_slicing() + # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention + if kwargs.get("enable_xformers", False): + self.model.enable_xformers_memory_efficient_attention() + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + scheduler_config = self.model.scheduler.config + scheduler = get_scheduler(config.sd_sampler, scheduler_config) + self.model.scheduler = scheduler + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] + + img_h, img_w = image.shape[:2] + + if self.is_native_control_inpaint: + control_image = make_inpaint_condition(image, mask) + output = self.model( + prompt=config.prompt, + image=control_image, + height=img_h, + width=img_w, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + controlnet_conditioning_scale=config.controlnet_conditioning_scale, + negative_prompt=config.negative_prompt, + generator=torch.manual_seed(config.sd_seed), + output_type="np.array", + callback=self.callback, + ).images[0] + else: + if "canny" in self.sd_controlnet_method: + canny_image = cv2.Canny(image, 100, 200) + canny_image = canny_image[:, :, None] + canny_image = np.concatenate( + [canny_image, canny_image, canny_image], axis=2 + ) + canny_image = PIL.Image.fromarray(canny_image) + control_image = canny_image + elif "openpose" in self.sd_controlnet_method: + from controlnet_aux import OpenposeDetector + + processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + control_image = processor(image, hand_and_face=True) + elif "depth" in self.sd_controlnet_method: + from transformers import pipeline + + depth_estimator = pipeline("depth-estimation") + depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"] + depth_image = np.array(depth_image) + depth_image = depth_image[:, :, None] + depth_image = np.concatenate( + [depth_image, depth_image, depth_image], axis=2 + ) + control_image = PIL.Image.fromarray(depth_image) + else: + raise NotImplementedError( + f"{self.sd_controlnet_method} not implemented" + ) + + mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") + image = PIL.Image.fromarray(image) + + output = self.model( + image=image, + control_image=control_image, + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=mask_image, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np.array", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + controlnet_conditioning_scale=config.controlnet_conditioning_scale, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model/ddim_sampler.py b/lama_cleaner/model/ddim_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..d153090df0828bba7490dc546df781ae6211c858 --- /dev/null +++ b/lama_cleaner/model/ddim_sampler.py @@ -0,0 +1,193 @@ +import torch +import numpy as np +from tqdm import tqdm + +from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like + +from loguru import logger + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear"): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + # array([1]) + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000]) + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample(self, steps, conditioning, batch_size, shape): + self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # samples: 1,3,128,128 + return self.ddim_sampling( + conditioning, + size, + quantize_denoised=False, + ddim_use_original_steps=False, + noise_dropout=0, + temperature=1.0, + ) + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + ddim_use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + device = self.model.betas.device + b = shape[0] + img = torch.randn(shape, device=device, dtype=cond.dtype) + timesteps = ( + self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + ) + + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + logger.info(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + ) + img, _ = outs + + return img + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + b, *_, device = *x.shape, x.device + e_t = self.model.apply_model(x, t, c) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: # 没用 + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: # 没用 + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py new file mode 100644 index 0000000000000000000000000000000000000000..65d8bc390d4da59689bd5ea0e64dc3fda71d1ab3 --- /dev/null +++ b/lama_cleaner/model/fcf.py @@ -0,0 +1,1733 @@ +import os +import random + +import cv2 +import torch +import numpy as np +import torch.fft as fft + +from lama_cleaner.schema import Config + +from lama_cleaner.helper import ( + load_model, + get_cache_path_by_url, + norm_img, + boxes_from_mask, + resize_max_size, +) +from lama_cleaner.model.base import InpaintModel +from torch import conv2d, nn +import torch.nn.functional as F + +from lama_cleaner.model.utils import ( + setup_filter, + _parse_scaling, + _parse_padding, + Conv2dLayer, + FullyConnectedLayer, + MinibatchStdLayer, + activation_funcs, + conv2d_resample, + bias_act, + upsample2d, + normalize_2nd_moment, + downsample2d, +) + + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): + assert isinstance(x, torch.Tensor) + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) + + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad( + x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] + ) + x = x[ + :, + :, + max(-pady0, 0) : x.shape[2] - max(-pady1, 0), + max(-padx0, 0) : x.shape[3] - max(-padx1, 0), + ] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +class EncoderEpilogue(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture="resnet", # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ["orig", "skip", "resnet"] + super().__init__() + self.in_channels = in_channels + self.cmap_dim = cmap_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + if architecture == "skip": + self.fromrgb = Conv2dLayer( + self.img_channels, in_channels, kernel_size=1, activation=activation + ) + self.mbstd = ( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + if mbstd_num_channels > 0 + else None + ) + self.conv = Conv2dLayer( + in_channels + mbstd_num_channels, + in_channels, + kernel_size=3, + activation=activation, + conv_clamp=conv_clamp, + ) + self.fc = FullyConnectedLayer( + in_channels * (resolution**2), z_dim, activation=activation + ) + self.dropout = torch.nn.Dropout(p=0.5) + + def forward(self, x, cmap, force_fp32=False): + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + # FromRGB. + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.mbstd is not None: + x = self.mbstd(x) + const_e = self.conv(x) + x = self.fc(const_e.flatten(1)) + x = self.dropout(x) + + # Conditioning. + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + assert x.dtype == dtype + return x, const_e + + +class EncoderBlock(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + freeze_layers=0, # Freeze-D: Number of layers to freeze. + ): + assert in_channels in [0, tmp_channels] + assert architecture in ["orig", "skip", "resnet"] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + 1 + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = use_fp16 and fp16_channels_last + self.register_buffer("resample_filter", setup_filter(resample_filter)) + + self.num_layers = 0 + + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = layer_idx >= freeze_layers + self.num_layers += 1 + yield trainable + + trainable_iter = trainable_gen() + + if in_channels == 0: + self.fromrgb = Conv2dLayer( + self.img_channels, + tmp_channels, + kernel_size=1, + activation=activation, + trainable=next(trainable_iter), + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + + self.conv0 = Conv2dLayer( + tmp_channels, + tmp_channels, + kernel_size=3, + activation=activation, + trainable=next(trainable_iter), + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + + self.conv1 = Conv2dLayer( + tmp_channels, + out_channels, + kernel_size=3, + activation=activation, + down=2, + trainable=next(trainable_iter), + resample_filter=resample_filter, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + + if architecture == "resnet": + self.skip = Conv2dLayer( + tmp_channels, + out_channels, + kernel_size=1, + bias=False, + down=2, + trainable=next(trainable_iter), + resample_filter=resample_filter, + channels_last=self.channels_last, + ) + + def forward(self, x, img, force_fp32=False): + # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + dtype = torch.float32 + memory_format = ( + torch.channels_last + if self.channels_last and not force_fp32 + else torch.contiguous_format + ) + + # Input. + if x is not None: + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0: + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = ( + downsample2d(img, self.resample_filter) + if self.architecture == "skip" + else None + ) + + # Main layers. + if self.architecture == "resnet": + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + feat = x.clone() + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + feat = x.clone() + x = self.conv1(x) + + assert x.dtype == dtype + return x, img, feat + + +class EncoderNetwork(torch.nn.Module): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + z_dim, # Input latent (Z) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture="orig", # Architecture: 'orig', 'skip', 'resnet'. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for EncoderEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.z_dim = z_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [ + 2**i for i in range(self.img_resolution_log2, 2, -1) + ] + channels_dict = { + res: min(channel_base // res, channel_max) + for res in self.block_resolutions + [4] + } + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict( + img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp + ) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = res >= fp16_resolution + use_fp16 = False + block = EncoderBlock( + in_channels, + tmp_channels, + out_channels, + resolution=res, + first_layer_idx=cur_layer_idx, + use_fp16=use_fp16, + **block_kwargs, + **common_kwargs, + ) + setattr(self, f"b{res}", block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork( + z_dim=0, + c_dim=c_dim, + w_dim=cmap_dim, + num_ws=None, + w_avg_beta=None, + **mapping_kwargs, + ) + self.b4 = EncoderEpilogue( + channels_dict[4], + cmap_dim=cmap_dim, + z_dim=z_dim * 2, + resolution=4, + **epilogue_kwargs, + **common_kwargs, + ) + + def forward(self, img, c, **block_kwargs): + x = None + feats = {} + for res in self.block_resolutions: + block = getattr(self, f"b{res}") + x, img, feat = block(x, img, **block_kwargs) + feats[res] = feat + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x, const_e = self.b4(x, cmap) + feats[4] = const_e + + B, _ = x.shape + z = torch.zeros( + (B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device + ) ## Noise for Co-Modulation + return x, z, feats + + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [ + i + for i in range(x.ndim) + if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1) + ] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims + 1 :]) + assert x.shape == shape + return x + + +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise=None, # Optional noise tensor to add to the output activations. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + padding=0, # Padding with respect to the upsampled image. + resample_filter=None, + # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate=True, # Apply weight demodulation? + flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * ( + 1 + / np.sqrt(in_channels * kh * kw) + / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True) + ) # max_Ikk + styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample( + x=x, + w=weight.to(x.dtype), + f=resample_filter, + up=up, + down=down, + padding=padding, + flip_weight=flip_weight, + ) + if demodulate and noise is not None: + x = fma( + x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype) + ) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + batch_size = int(batch_size) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample( + x=x, + w=w.to(x.dtype), + f=resample_filter, + up=up, + down=down, + padding=padding, + groups=batch_size, + flip_weight=flip_weight, + ) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + + +class SynthesisLayer(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=True, # Enable noise input? + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last=False, # Use channels_last format for the weights? + ): + super().__init__() + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + ) + if use_noise: + self.register_buffer("noise_const", torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1): + assert noise_mode in ["random", "const", "none"] + in_resolution = self.resolution // self.up + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == "random": + noise = ( + torch.randn( + [x.shape[0], 1, self.resolution, self.resolution], device=x.device + ) + * self.noise_strength + ) + if self.use_noise and noise_mode == "const": + noise = self.noise_const * self.noise_strength + + flip_weight = self.up == 1 # slightly faster + x = modulated_conv2d( + x=x, + weight=self.weight, + styles=styles, + noise=noise, + up=self.up, + padding=self.padding, + resample_filter=self.resample_filter, + flip_weight=flip_weight, + fused_modconv=fused_modconv, + ) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = F.leaky_relu(x, negative_slope=0.2, inplace=False) + if act_gain != 1: + x = x * act_gain + if act_clamp is not None: + x = x.clamp(-act_clamp, act_clamp) + return x + + +class ToRGBLayer(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + w_dim, + kernel_size=1, + conv_clamp=None, + channels_last=False, + ): + super().__init__() + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) + self.weight = torch.nn.Parameter( + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + ) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d( + x=x, + weight=self.weight, + styles=styles, + demodulate=False, + fused_modconv=fused_modconv, + ) + x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + +class SynthesisForeword(torch.nn.Module): + def __init__( + self, + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + in_channels, + img_channels, # Number of input color channels. + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + ): + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + self.fc = FullyConnectedLayer( + self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation + ) + self.conv = SynthesisLayer( + self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4 + ) + + if architecture == "skip": + self.torgb = ToRGBLayer( + self.in_channels, + self.img_channels, + kernel_size=1, + w_dim=(z_dim // 2) * 3, + ) + + def forward(self, x, ws, feats, img, force_fp32=False): + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + x_global = x.clone() + # ToRGB. + x = self.fc(x) + x = x.view(-1, self.z_dim // 2, 4, 4) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + x_skip = feats[4].clone() + x = x + x_skip + + mod_vector = [] + mod_vector.append(ws[:, 0]) + mod_vector.append(x_global.clone()) + mod_vector = torch.cat(mod_vector, dim=1) + + x = self.conv(x, mod_vector) + + mod_vector = [] + mod_vector.append(ws[:, 2 * 2 - 3]) + mod_vector.append(x_global.clone()) + mod_vector = torch.cat(mod_vector, dim=1) + + if self.architecture == "skip": + img = self.torgb(x, mod_vector) + img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) + + assert x.dtype == dtype + return x, img + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=False), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FourierUnit(nn.Module): + def __init__( + self, + in_channels, + out_channels, + groups=1, + spatial_scale_factor=None, + spatial_scale_mode="bilinear", + spectral_pos_encoding=False, + use_se=False, + se_kwargs=None, + ffc3d=False, + fft_norm="ortho", + ): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d( + in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + groups=self.groups, + bias=False, + ) + self.relu = torch.nn.ReLU(inplace=False) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate( + x, + scale_factor=self.spatial_scale_factor, + mode=self.spatial_scale_mode, + align_corners=False, + ) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view( + ( + batch, + -1, + ) + + ffted.size()[3:] + ) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = ( + torch.linspace(0, 1, height)[None, None, :, None] + .expand(batch, 1, height, width) + .to(ffted) + ) + coords_hor = ( + torch.linspace(0, 1, width)[None, None, None, :] + .expand(batch, 1, height, width) + .to(ffted) + ) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(ffted) + + ffted = ( + ffted.view( + ( + batch, + -1, + 2, + ) + + ffted.size()[2:] + ) + .permute(0, 1, 3, 4, 2) + .contiguous() + ) # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn( + ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm + ) + + if self.spatial_scale_factor is not None: + output = F.interpolate( + output, + size=orig_size, + mode=self.spatial_scale_mode, + align_corners=False, + ) + + return output + + +class SpectralTransform(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride=1, + groups=1, + enable_lfu=True, + **fu_kwargs, + ): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False + ), + # nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True), + ) + self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False + ) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat( + torch.split(x[:, : c // 4], split_s, dim=-2), dim=1 + ).contiguous() + xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class FFC(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + enable_lfu=True, + padding_type="reflect", + gated=False, + **spectral_kwargs, + ): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + # groups_g = 1 if groups == 1 else int(groups * ratio_gout) + # groups_l = 1 if groups == 1 else groups - groups_g + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module( + in_cl, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module( + in_cl, + out_cg, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module( + in_cg, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, + out_cg, + stride, + 1 if groups == 1 else groups // 2, + enable_lfu, + **spectral_kwargs, + ) + + self.gated = gated + module = ( + nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + ) + self.gate = module(in_channels, 2, 1) + + def forward(self, x, fname=None): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + spec_x = self.convg2g(x_g) + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + spec_x + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_layer=nn.SyncBatchNorm, + activation_layer=nn.Identity, + padding_type="reflect", + enable_lfu=True, + **kwargs, + ): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC( + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride, + padding, + dilation, + groups, + bias, + enable_lfu, + padding_type=padding_type, + **kwargs, + ) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + # self.bn_l = lnorm(out_channels - global_channels) + # self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x, fname=None): + x_l, x_g = self.ffc( + x, + fname=fname, + ) + x_l = self.act_l(x_l) + x_g = self.act_g(x_g) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__( + self, + dim, + padding_type, + norm_layer, + activation_layer=nn.ReLU, + dilation=1, + spatial_transform_kwargs=None, + inline=False, + ratio_gin=0.75, + ratio_gout=0.75, + ): + super().__init__() + self.conv1 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + self.conv2 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + self.inline = inline + + def forward(self, x, fname=None): + if self.inline: + x_l, x_g = ( + x[:, : -self.conv1.ffc.global_in_num], + x[:, -self.conv1.ffc.global_in_num :], + ) + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g), fname=fname) + x_l, x_g = self.conv2((x_l, x_g), fname=fname) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCBlock(torch.nn.Module): + def __init__( + self, + dim, # Number of output/input channels. + kernel_size, # Width and height of the convolution kernel. + padding, + ratio_gin=0.75, + ratio_gout=0.75, + activation="linear", # Activation function: 'relu', 'lrelu', etc. + ): + super().__init__() + if activation == "linear": + self.activation = nn.Identity + else: + self.activation = nn.ReLU + self.padding = padding + self.kernel_size = kernel_size + self.ffc_block = FFCResnetBlock( + dim=dim, + padding_type="reflect", + norm_layer=nn.SyncBatchNorm, + activation_layer=self.activation, + dilation=1, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + + self.concat_layer = ConcatTupleLayer() + + def forward(self, gen_ft, mask, fname=None): + x = gen_ft.float() + + x_l, x_g = ( + x[:, : -self.ffc_block.conv1.ffc.global_in_num], + x[:, -self.ffc_block.conv1.ffc.global_in_num :], + ) + id_l, id_g = x_l, x_g + + x_l, x_g = self.ffc_block((x_l, x_g), fname=fname) + x_l, x_g = id_l + x_l, id_g + x_g + x = self.concat_layer((x_l, x_g)) + + return x + gen_ft.float() + + +class FFCSkipLayer(torch.nn.Module): + def __init__( + self, + dim, # Number of input/output channels. + kernel_size=3, # Convolution kernel size. + ratio_gin=0.75, + ratio_gout=0.75, + ): + super().__init__() + self.padding = kernel_size // 2 + + self.ffc_act = FFCBlock( + dim=dim, + kernel_size=kernel_size, + activation=nn.ReLU, + padding=self.padding, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + + def forward(self, gen_ft, mask, fname=None): + x = self.ffc_act(gen_ft, mask, fname=fname) + return x + + +class SynthesisBlock(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ["orig", "skip", "resnet"] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = use_fp16 and fp16_channels_last + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1} + + if in_channels != 0 and resolution >= 8: + self.ffc_skip = nn.ModuleList() + for _ in range(self.res_ffc[resolution]): + self.ffc_skip.append(FFCSkipLayer(dim=out_channels)) + + if in_channels == 0: + self.const = torch.nn.Parameter( + torch.randn([out_channels, resolution, resolution]) + ) + + if in_channels != 0: + self.conv0 = SynthesisLayer( + in_channels, + out_channels, + w_dim=w_dim * 3, + resolution=resolution, + up=2, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + **layer_kwargs, + ) + self.num_conv += 1 + + self.conv1 = SynthesisLayer( + out_channels, + out_channels, + w_dim=w_dim * 3, + resolution=resolution, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + **layer_kwargs, + ) + self.num_conv += 1 + + if is_last or architecture == "skip": + self.torgb = ToRGBLayer( + out_channels, + img_channels, + w_dim=w_dim * 3, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) + self.num_torgb += 1 + + if in_channels != 0 and architecture == "resnet": + self.skip = Conv2dLayer( + in_channels, + out_channels, + kernel_size=1, + bias=False, + up=2, + resample_filter=resample_filter, + channels_last=self.channels_last, + ) + + def forward( + self, + x, + mask, + feats, + img, + ws, + fname=None, + force_fp32=False, + fused_modconv=None, + **layer_kwargs, + ): + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + dtype = torch.float32 + memory_format = ( + torch.channels_last + if self.channels_last and not force_fp32 + else torch.contiguous_format + ) + if fused_modconv is None: + fused_modconv = (not self.training) and ( + dtype == torch.float32 or int(x.shape[0]) == 1 + ) + + x = x.to(dtype=dtype, memory_format=memory_format) + x_skip = ( + feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format) + ) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == "resnet": + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0( + x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) + if len(self.ffc_skip) > 0: + mask = F.interpolate( + mask, + size=x_skip.shape[2:], + ) + z = x + x_skip + for fres in self.ffc_skip: + z = fres(z, mask) + x = x + z + else: + x = x + x_skip + x = self.conv1( + x, + ws[1].clone(), + fused_modconv=fused_modconv, + gain=np.sqrt(0.5), + **layer_kwargs, + ) + x = y.add_(x) + else: + x = self.conv0( + x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) + if len(self.ffc_skip) > 0: + mask = F.interpolate( + mask, + size=x_skip.shape[2:], + ) + z = x + x_skip + for fres in self.ffc_skip: + z = fres(z, mask) + x = x + z + else: + x = x + x_skip + x = self.conv1( + x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) + # ToRGB. + if img is not None: + img = upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == "skip": + y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + x = x.to(dtype=dtype) + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + +class SynthesisNetwork(torch.nn.Module): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + z_dim, # Output Latent (Z) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [ + 2**i for i in range(3, self.img_resolution_log2 + 1) + ] + channels_dict = { + res: min(channel_base // res, channel_max) for res in self.block_resolutions + } + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.foreword = SynthesisForeword( + img_channels=img_channels, + in_channels=min(channel_base // 4, channel_max), + z_dim=z_dim * 2, + resolution=4, + ) + + self.num_ws = self.img_resolution_log2 * 2 - 2 + for res in self.block_resolutions: + if res // 2 in channels_dict.keys(): + in_channels = channels_dict[res // 2] if res > 4 else 0 + else: + in_channels = min(channel_base // (res // 2), channel_max) + out_channels = channels_dict[res] + use_fp16 = res >= fp16_resolution + use_fp16 = False + is_last = res == self.img_resolution + block = SynthesisBlock( + in_channels, + out_channels, + w_dim=w_dim, + resolution=res, + img_channels=img_channels, + is_last=is_last, + use_fp16=use_fp16, + **block_kwargs, + ) + setattr(self, f"b{res}", block) + + def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs): + + img = None + + x, img = self.foreword(x_global, ws, feats, img) + + for res in self.block_resolutions: + block = getattr(self, f"b{res}") + mod_vector0 = [] + mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5]) + mod_vector0.append(x_global.clone()) + mod_vector0 = torch.cat(mod_vector0, dim=1) + + mod_vector1 = [] + mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4]) + mod_vector1.append(x_global.clone()) + mod_vector1 = torch.cat(mod_vector1, dim=1) + + mod_vector_rgb = [] + mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3]) + mod_vector_rgb.append(x_global.clone()) + mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1) + x, img = block( + x, + mask, + feats, + img, + (mod_vector0, mod_vector1, mod_vector_rgb), + fname=fname, + **block_kwargs, + ) + return img + + +class MappingNetwork(torch.nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = ( + [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + ) + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer( + in_features, + out_features, + activation=activation, + lr_multiplier=lr_multiplier, + ) + setattr(self, f"fc{idx}", layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer("w_avg", torch.zeros([w_dim])) + + def forward( + self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False + ): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function("input"): + if self.z_dim > 0: + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f"fc{idx}") + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + with torch.autograd.profiler.record_function("update_w_avg"): + self.w_avg.copy_( + x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta) + ) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function("broadcast"): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function("truncate"): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp( + x[:, :truncation_cutoff], truncation_psi + ) + return x + + +class Generator(torch.nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + encoder_kwargs={}, # Arguments for EncoderNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.encoder = EncoderNetwork( + c_dim=c_dim, + z_dim=z_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **encoder_kwargs, + ) + self.synthesis = SynthesisNetwork( + z_dim=z_dim, + w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs, + ) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork( + z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs + ) + + def forward( + self, + img, + c, + fname=None, + truncation_psi=1, + truncation_cutoff=None, + **synthesis_kwargs, + ): + mask = img[:, -1].unsqueeze(1) + x_global, z, feats = self.encoder(img, c) + ws = self.mapping( + z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff + ) + img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs) + return img + + +FCF_MODEL_URL = os.environ.get( + "FCF_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth", +) +FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211") + + +class FcF(InpaintModel): + name = "fcf" + min_size = 512 + pad_mod = 512 + pad_to_square = True + + def init_model(self, device, **kwargs): + seed = 0 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + kwargs = { + "channel_base": 1 * 32768, + "channel_max": 512, + "num_fp16_res": 4, + "conv_clamp": 256, + } + G = Generator( + z_dim=512, + c_dim=0, + w_dim=512, + img_resolution=512, + img_channels=3, + synthesis_kwargs=kwargs, + encoder_kwargs=kwargs, + mapping_kwargs={"num_layers": 2}, + ) + self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5) + self.label = torch.zeros([1, self.model.c_dim], device=device) + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL)) + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + if image.shape[0] == 512 and image.shape[1] == 512: + return self._pad_forward(image, mask, config) + + boxes = boxes_from_mask(mask) + crop_result = [] + config.hd_strategy_crop_margin = 128 + for box in boxes: + crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config) + origin_size = crop_image.shape[:2] + resize_image = resize_max_size(crop_image, size_limit=512) + resize_mask = resize_max_size(crop_mask, size_limit=512) + inpaint_result = self._pad_forward(resize_image, resize_mask, config) + + # only paste masked area result + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) + + original_pixel_indices = crop_mask < 127 + inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][ + original_pixel_indices + ] + + crop_result.append((inpaint_result, crop_box)) + + inpaint_result = image[:, :, ::-1] + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + inpaint_result[y1:y2, x1:x2, :] = crop_image + + return inpaint_result + + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + mask = (mask > 120) * 255 + mask = norm_img(mask) + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + erased_img = image * (1 - mask) + input_image = torch.cat([0.5 - mask, erased_img], dim=1) + + output = self.model( + input_image, self.label, truncation_psi=0.1, noise_mode="none" + ) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py new file mode 100644 index 0000000000000000000000000000000000000000..7f9ab063b370b92ed8093c7f0fe5ea8abdaf5ebb --- /dev/null +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -0,0 +1,83 @@ +import PIL.Image +import cv2 +import torch +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import set_seed +from lama_cleaner.schema import Config + + +class InstructPix2Pix(DiffusionInpaintModel): + name = "instruct_pix2pix" + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers import StableDiffusionInstructPix2PixPipeline + fp16 = not kwargs.get('no_half', False) + + model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} + if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update(dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False + )) + + use_gpu = device == torch.device('cuda') and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained( + "timbrooks/instruct-pix2pix", + revision="fp16" if use_gpu and fp16 else "main", + torch_dtype=torch_dtype, + **model_kwargs + ) + + self.model.enable_attention_slicing() + if kwargs.get('enable_xformers', False): + self.model.enable_xformers_memory_efficient_attention() + + if kwargs.get('cpu_offload', False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0] + """ + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + num_inference_steps=config.p2p_steps, + image_guidance_scale=config.p2p_image_guidance_scale, + guidance_scale=config.p2p_guidance_scale, + output_type="np.array", + generator=torch.manual_seed(config.sd_seed) + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + # + # def forward_post_process(self, result, image, mask, config): + # if config.sd_match_histograms: + # result = self._match_histograms(result, image[:, :, ::-1], mask) + # + # if config.sd_mask_blur != 0: + # k = 2 * config.sd_mask_blur + 1 + # mask = cv2.GaussianBlur(mask, (k, k), 0) + # return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffefd0d97f1b896bf53dc934120bb007c926371 --- /dev/null +++ b/lama_cleaner/model/lama.py @@ -0,0 +1,51 @@ +import os + +import cv2 +import numpy as np +import torch + +from lama_cleaner.helper import ( + norm_img, + get_cache_path_by_url, + load_jit_model, +) +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +LAMA_MODEL_URL = os.environ.get( + "LAMA_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", +) +LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500") + + +class LaMa(InpaintModel): + name = "lama" + pad_mod = 8 + + def init_model(self, device, **kwargs): + self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval() + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL)) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W] + return: BGR IMAGE + """ + image = norm_img(image) + mask = norm_img(mask) + + mask = (mask > 0) * 1 + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + + inpainted_image = self.model(image, mask) + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..2c476ceb0bc704cdcb6d15bd7afe163f78a701fc --- /dev/null +++ b/lama_cleaner/model/ldm.py @@ -0,0 +1,333 @@ +import os +from functools import wraps + +import numpy as np +import torch +import torch.nn as nn + +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.model.ddim_sampler import DDIMSampler +from lama_cleaner.model.plms_sampler import PLMSSampler +from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding +from lama_cleaner.schema import Config, LDMSampler + +# torch.manual_seed(42) + + +def conditional_autocast(func): + @wraps(func) + def wrapper(*args, **kwargs): + if torch.cuda.is_available(): + with torch.cuda.amp.autocast(): + return func(*args, **kwargs) + else: + return func(*args, **kwargs) + return wrapper + + +LDM_ENCODE_MODEL_URL = os.environ.get( + "LDM_ENCODE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt", +) +LDM_ENCODE_MODEL_MD5 = os.environ.get( + "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296" +) + +LDM_DECODE_MODEL_URL = os.environ.get( + "LDM_DECODE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt", +) +LDM_DECODE_MODEL_MD5 = os.environ.get( + "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c" +) + +LDM_DIFFUSION_MODEL_URL = os.environ.get( + "LDM_DIFFUSION_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt", +) + +LDM_DIFFUSION_MODEL_MD5 = os.environ.get( + "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d" +) + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + device, + timesteps=1000, + beta_schedule="linear", + linear_start=0.0015, + linear_end=0.0205, + cosine_s=0.008, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + parameterization="eps", # all assuming fixed variance schedules + use_positional_encodings=False, + ): + super().__init__() + self.device = device + self.parameterization = parameterization + self.use_positional_encodings = use_positional_encodings + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + self.register_schedule( + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + self.device, + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + def to_torch(x): return torch.tensor(x, dtype=torch.float32).to(self.device) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + +class LatentDiffusion(DDPM): + def __init__( + self, + diffusion_model, + device, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs, + ): + self.num_timesteps_cond = 1 + self.scale_by_std = scale_by_std + super().__init__(device, *args, **kwargs) + self.diffusion_model = diffusion_model + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.num_downs = 2 + self.scale_factor = scale_factor + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def apply_model(self, x_noisy, t, cond): + # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128 + t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False) + x_recon = self.diffusion_model(x_noisy, t_emb, cond) + return x_recon + + +class LDM(InpaintModel): + name = "ldm" + pad_mod = 32 + + def __init__(self, device, fp16: bool = True, **kwargs): + self.fp16 = fp16 + super().__init__(device) + self.device = device + + def init_model(self, device, **kwargs): + self.diffusion_model = load_jit_model( + LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5 + ) + self.cond_stage_model_decode = load_jit_model( + LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5 + ) + self.cond_stage_model_encode = load_jit_model( + LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5 + ) + if self.fp16 and "cuda" in str(device): + self.diffusion_model = self.diffusion_model.half() + self.cond_stage_model_decode = self.cond_stage_model_decode.half() + self.cond_stage_model_encode = self.cond_stage_model_encode.half() + + self.model = LatentDiffusion(self.diffusion_model, device) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL), + get_cache_path_by_url(LDM_DECODE_MODEL_URL), + get_cache_path_by_url(LDM_ENCODE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + @conditional_autocast + def forward(self, image, mask, config: Config): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + # image [1,3,512,512] float32 + # mask: [1,1,512,512] float32 + # masked_image: [1,3,512,512] float32 + if config.ldm_sampler == LDMSampler.ddim: + sampler = DDIMSampler(self.model) + elif config.ldm_sampler == LDMSampler.plms: + sampler = PLMSSampler(self.model) + else: + raise ValueError() + + steps = config.ldm_steps + image = norm_img(image) + mask = norm_img(mask) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + masked_image = (1 - mask) * image + + mask = self._norm(mask) + masked_image = self._norm(masked_image) + + c = self.cond_stage_model_encode(masked_image) + torch.cuda.empty_cache() + + cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128 + c = torch.cat((c, cc), dim=1) # 1,4,128,128 + + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim = sampler.sample( + steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape + ) + torch.cuda.empty_cache() + x_samples_ddim = self.cond_stage_model_decode( + samples_ddim + ) # samples_ddim: 1, 3, 128, 128 float32 + torch.cuda.empty_cache() + + # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0) + inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + # inpainted = (1 - mask) * image + mask * predicted_image + inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 + inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1] + return inpainted_image + + def _norm(self, tensor): + return tensor * 2.0 - 1.0 diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py new file mode 100644 index 0000000000000000000000000000000000000000..c152c8cb8f127ac2a7a0ee5fa3e17df52d310438 --- /dev/null +++ b/lama_cleaner/model/manga.py @@ -0,0 +1,91 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import time +from loguru import logger + +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + + +MANGA_INPAINTOR_MODEL_URL = os.environ.get( + "MANGA_INPAINTOR_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit", +) +MANGA_INPAINTOR_MODEL_MD5 = os.environ.get( + "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c" +) + +MANGA_LINE_MODEL_URL = os.environ.get( + "MANGA_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/erika.jit", +) +MANGA_LINE_MODEL_MD5 = os.environ.get( + "MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644" +) + + +class Manga(InpaintModel): + name = "manga" + pad_mod = 16 + + def init_model(self, device, **kwargs): + self.inpaintor_model = load_jit_model( + MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5 + ) + self.line_model = load_jit_model( + MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5 + ) + self.seed = 42 + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL), + get_cache_path_by_url(MANGA_LINE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def forward(self, image, mask, config: Config): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + seed = self.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + gray_img = torch.from_numpy( + gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32) + ).to(self.device) + start = time.time() + lines = self.line_model(gray_img) + torch.cuda.empty_cache() + lines = torch.clamp(lines, 0, 255) + logger.info(f"erika_model time: {time.time() - start}") + + mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device) + mask = mask.permute(0, 3, 1, 2) + mask = torch.where(mask > 0.5, 1.0, 0.0) + noise = torch.randn_like(mask) + ones = torch.ones_like(mask) + + gray_img = gray_img / 255 * 2 - 1.0 + lines = lines / 255 * 2 - 1.0 + + start = time.time() + inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones) + logger.info(f"image_inpaintor_model time: {time.time() - start}") + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8) + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR) + return cur_res diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py new file mode 100644 index 0000000000000000000000000000000000000000..9a98bee6b0fb2bf40fb3deda260080424a636f76 --- /dev/null +++ b/lama_cleaner/model/mat.py @@ -0,0 +1,1935 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.model.utils import ( + setup_filter, + Conv2dLayer, + FullyConnectedLayer, + conv2d_resample, + bias_act, + upsample2d, + activation_funcs, + MinibatchStdLayer, + to_2tuple, + normalize_2nd_moment, + set_seed, +) +from lama_cleaner.schema import Config + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + style_dim, # dimension of the style code + demodulate=True, # perfrom demodulation + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + ): + super().__init__() + self.demodulate = demodulate + + self.weight = torch.nn.Parameter( + torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]) + ) + self.out_channels = out_channels + self.kernel_size = kernel_size + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.padding = self.kernel_size // 2 + self.up = up + self.down = down + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + + self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1) + + def forward(self, x, style): + batch, in_channels, height, width = x.shape + style = self.affine(style).view(batch, 1, in_channels, 1, 1) + weight = self.weight * self.weight_gain * style + + if self.demodulate: + decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() + weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) + + weight = weight.view( + batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size + ) + x = x.view(1, batch * in_channels, height, width) + x = conv2d_resample( + x=x, + w=weight, + f=self.resample_filter, + up=self.up, + down=self.down, + padding=self.padding, + groups=batch, + ) + out = x.view(batch, self.out_channels, *x.shape[2:]) + + return out + + +class StyleConv(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + style_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=False, # Enable noise input? + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + demodulate=True, # perform demodulation + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + up=up, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + ) + + self.use_noise = use_noise + self.resolution = resolution + if use_noise: + self.register_buffer("noise_const", torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.activation = activation + self.act_gain = activation_funcs[activation].def_gain + self.conv_clamp = conv_clamp + + def forward(self, x, style, noise_mode="random", gain=1): + x = self.conv(x, style) + + assert noise_mode in ["random", "const", "none"] + + if self.use_noise: + if noise_mode == "random": + xh, xw = x.size()[-2:] + noise = ( + torch.randn([x.shape[0], 1, xh, xw], device=x.device) + * self.noise_strength + ) + if noise_mode == "const": + noise = self.noise_const * self.noise_strength + x = x + noise + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + out = bias_act( + x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp + ) + + return out + + +class ToRGB(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + style_dim, + kernel_size=1, + resample_filter=[1, 3, 3, 1], + conv_clamp=None, + demodulate=False, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + ) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + + def forward(self, x, style, skip=None): + x = self.conv(x, style) + out = bias_act(x, self.bias, clamp=self.conv_clamp) + + if skip is not None: + if skip.shape != out.shape: + skip = upsample2d(skip, self.resample_filter) + out = out + skip + + return out + + +def get_style_code(a, b): + return torch.cat([a, b], dim=1) + + +class DecBlockFirst(nn.Module): + def __init__( + self, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.fc = FullyConnectedLayer( + in_features=in_channels * 2, + out_features=in_channels * 4 ** 2, + activation=activation, + ) + self.conv = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = x + E_features[2] + style = get_style_code(ws[:, 0], gs) + x = self.conv(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlockFirstV2(nn.Module): + def __init__( + self, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + # x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = self.conv0(x) + x = x + E_features[2] + style = get_style_code(ws[:, 0], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlock(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): # res = 2, ..., resolution_log2 + super().__init__() + self.res = res + + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, ws, gs, E_features, noise_mode="random"): + style = get_style_code(ws[:, self.res * 2 - 5], gs) + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + E_features[self.res] + style = get_style_code(ws[:, self.res * 2 - 4], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, self.res * 2 - 3], gs) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class MappingNet(torch.nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + torch_dtype=torch.float32, + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + self.torch_dtype = torch_dtype + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = ( + [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + ) + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer( + in_features, + out_features, + activation=activation, + lr_multiplier=lr_multiplier, + ) + setattr(self, f"fc{idx}", layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer("w_avg", torch.zeros([w_dim])) + + def forward( + self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False + ): + # Embed, normalize, and concat inputs. + x = None + if self.z_dim > 0: + x = normalize_2nd_moment(z) + if self.c_dim > 0: + y = normalize_2nd_moment(self.embed(c)) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f"fc{idx}") + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp( + x[:, :truncation_cutoff], truncation_psi + ) + + return x + + +class DisFromRGB(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 + super().__init__() + self.conv = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + + def forward(self, x): + return self.conv(x) + + +class DisBlock(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 + super().__init__() + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + down=2, + activation=activation, + ) + self.skip = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + down=2, + bias=False, + ) + + def forward(self, x): + skip = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + out = skip + x + + return out + + +class Discriminator(torch.nn.Module): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation="lrelu", + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + self.resolution_log2 = resolution_log2 + + def nf(stage): + return np.clip( + int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max + ) + + if cmap_dim == None: + cmap_dim = nf(2) + if c_dim == 0: + cmap_dim = 0 + self.cmap_dim = cmap_dim + + if c_dim > 0: + self.mapping = MappingNet( + z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None + ) + + Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] + for res in range(resolution_log2, 2, -1): + Dis.append(DisBlock(nf(res), nf(res - 1), activation)) + + if mbstd_num_channels > 0: + Dis.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis.append( + Conv2dLayer( + nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation + ) + ) + self.Dis = nn.Sequential(*Dis) + + self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, images_in, masks_in, c): + x = torch.cat([masks_in - 0.5, images_in], dim=1) + x = self.Dis(x) + x = self.fc1(self.fc0(x.flatten(start_dim=1))) + + if self.c_dim > 0: + cmap = self.mapping(None, c) + + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + return x + + +def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): + NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} + return NF[2 ** stage] + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = FullyConnectedLayer( + in_features=in_features, out_features=hidden_features, activation="lrelu" + ) + self.fc2 = FullyConnectedLayer( + in_features=hidden_features, out_features=out_features + ) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + # B = windows.shape[0] / (H * W / window_size / window_size) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class Conv2dLayerPartial(nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + trainable=True, # Update the weights of this layer during training? + ): + super().__init__() + self.conv = Conv2dLayer( + in_channels, + out_channels, + kernel_size, + bias, + activation, + up, + down, + resample_filter, + conv_clamp, + trainable, + ) + + self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) + self.slide_winsize = kernel_size ** 2 + self.stride = down + self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 + + def forward(self, x, mask=None): + if mask is not None: + with torch.no_grad(): + if self.weight_maskUpdater.type() != x.type(): + self.weight_maskUpdater = self.weight_maskUpdater.to(x) + update_mask = F.conv2d( + mask, + self.weight_maskUpdater, + bias=None, + stride=self.stride, + padding=self.padding, + ) + mask_ratio = self.slide_winsize / (update_mask.to(torch.float32) + 1e-8) + update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1 + mask_ratio = torch.mul(mask_ratio, update_mask).to(x.dtype) + x = self.conv(x) + x = torch.mul(x, mask_ratio) + return x, update_mask + else: + x = self.conv(x) + return x, None + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + down_ratio=1, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = FullyConnectedLayer(in_features=dim, out_features=dim) + self.k = FullyConnectedLayer(in_features=dim, out_features=dim) + self.v = FullyConnectedLayer(in_features=dim, out_features=dim) + self.proj = FullyConnectedLayer(in_features=dim, out_features=dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask_windows=None, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + norm_x = F.normalize(x, p=2.0, dim=-1, eps=torch.finfo(x.dtype).eps) + q = ( + self.q(norm_x) + .reshape(B_, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.k(norm_x) + .view(B_, -1, self.num_heads, C // self.num_heads) + .permute(0, 2, 3, 1) + ) + v = ( + self.v(x) + .view(B_, -1, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + attn = (q @ k) * self.scale + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + if mask_windows is not None: + attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1) + attn = attn + attn_mask_windows.masked_fill( + attn_mask_windows == 0, float(-100.0) + ).masked_fill(attn_mask_windows == 1, float(0.0)) + with torch.no_grad(): + mask_windows = torch.clamp( + torch.sum(mask_windows, dim=1, keepdim=True), 0, 1 + ).repeat(1, N, 1) + + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x, mask_windows + + +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + down_ratio=1, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + if self.shift_size > 0: + down_ratio = 1 + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + down_ratio=down_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.fuse = FullyConnectedLayer( + in_features=dim * 2, out_features=dim, activation="lrelu" + ) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + return attn_mask + + def forward(self, x, x_size, mask=None): + # H, W = self.input_resolution + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + if mask is not None: + mask = mask.view(B, H, W, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + if mask is not None: + shifted_mask = torch.roll( + mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + else: + shifted_x = x + if mask is not None: + shifted_mask = mask + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + if mask is not None: + mask_windows = window_partition(shifted_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1) + else: + mask_windows = None + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows, mask_windows = self.attn( + x_windows, mask_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + else: + attn_windows, mask_windows = self.attn( + x_windows, + mask_windows, + mask=self.calculate_mask(x_size).to(x.dtype).to(x.device), + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + if mask is not None: + mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1) + shifted_mask = window_reverse(mask_windows, self.window_size, H, W) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + if mask is not None: + mask = torch.roll( + shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + if mask is not None: + mask = shifted_mask + x = x.view(B, H * W, C) + if mask is not None: + mask = mask.view(B, H * W, 1) + + # FFN + x = self.fuse(torch.cat([shortcut, x], dim=-1)) + x = self.mlp(x) + + return x, mask + + +class PatchMerging(nn.Module): + def __init__(self, in_channels, out_channels, down=2): + super().__init__() + self.conv = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation="lrelu", + down=down, + ) + self.down = down + + def forward(self, x, x_size, mask=None): + x = token2feature(x, x_size) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(x, mask) + if self.down != 1: + ratio = 1 / self.down + x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio)) + x = feature2token(x) + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class PatchUpsampling(nn.Module): + def __init__(self, in_channels, out_channels, up=2): + super().__init__() + self.conv = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation="lrelu", + up=up, + ) + self.up = up + + def forward(self, x, x_size, mask=None): + x = token2feature(x, x_size) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(x, mask) + if self.up != 1: + x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up)) + x = feature2token(x) + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + down_ratio=1, + mlp_ratio=2.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # patch merging layer + if downsample is not None: + # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample + else: + self.downsample = None + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + down_ratio=down_ratio, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + self.conv = Conv2dLayerPartial( + in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu" + ) + + def forward(self, x, x_size, mask=None): + if self.downsample is not None: + x, x_size, mask = self.downsample(x, x_size, mask) + identity = x + for blk in self.blocks: + if self.use_checkpoint: + x, mask = checkpoint.checkpoint(blk, x, x_size, mask) + else: + x, mask = blk(x, x_size, mask) + if mask is not None: + mask = token2feature(mask, x_size) + x, mask = self.conv(token2feature(x, x_size), mask) + x = feature2token(x) + identity + if mask is not None: + mask = feature2token(mask) + return x, x_size, mask + + +class ToToken(nn.Module): + def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1): + super().__init__() + + self.proj = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=dim, + kernel_size=kernel_size, + activation="lrelu", + ) + + def forward(self, x, mask): + x, mask = self.proj(x, mask) + + return x, mask + + +class EncFromRGB(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 + super().__init__() + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + self.conv1 = Conv2dLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) + + def forward(self, x): + x = self.conv0(x) + x = self.conv1(x) + + return x + + +class ConvBlockDown(nn.Module): + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log + super().__init__() + + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + down=2, + ) + self.conv1 = Conv2dLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) + + def forward(self, x): + x = self.conv0(x) + x = self.conv1(x) + + return x + + +def token2feature(x, x_size): + B, N, C = x.shape + h, w = x_size + x = x.permute(0, 2, 1).reshape(B, C, h, w) + return x + + +def feature2token(x): + B, C, H, W = x.shape + x = x.view(B, C, -1).transpose(1, 2) + return x + + +class Encoder(nn.Module): + def __init__( + self, + res_log2, + img_channels, + activation, + patch_size=5, + channels=16, + drop_path_rate=0.1, + ): + super().__init__() + + self.resolution = [] + + for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16 + res = 2 ** i + self.resolution.append(res) + if i == res_log2: + block = EncFromRGB(img_channels * 2 + 1, nf(i), activation) + else: + block = ConvBlockDown(nf(i + 1), nf(i), activation) + setattr(self, "EncConv_Block_%dx%d" % (res, res), block) + + def forward(self, x): + out = {} + for res in self.resolution: + res_log2 = int(np.log2(res)) + x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x) + out[res_log2] = x + + return out + + +class ToStyle(nn.Module): + def __init__(self, in_channels, out_channels, activation, drop_rate): + super().__init__() + self.conv = nn.Sequential( + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + ) + + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = FullyConnectedLayer( + in_features=in_channels, out_features=out_channels, activation=activation + ) + # self.dropout = nn.Dropout(drop_rate) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + x = self.fc(x.flatten(start_dim=1)) + # x = self.dropout(x) + + return x + + +class DecBlockFirstV2(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.res = res + + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + # x = self.fc(x).view(x.shape[0], -1, 4, 4) + x = self.conv0(x) + x = x + E_features[self.res] + style = get_style_code(ws[:, 0], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, 1], gs) + img = self.toRGB(x, style, skip=None) + + return x, img + + +class DecBlock(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): # res = 4, ..., resolution_log2 + super().__init__() + self.res = res + + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, ws, gs, E_features, noise_mode="random"): + style = get_style_code(ws[:, self.res * 2 - 9], gs) + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + E_features[self.res] + style = get_style_code(ws[:, self.res * 2 - 8], gs) + x = self.conv1(x, style, noise_mode=noise_mode) + style = get_style_code(ws[:, self.res * 2 - 7], gs) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class Decoder(nn.Module): + def __init__( + self, res_log2, activation, style_dim, use_noise, demodulate, img_channels + ): + super().__init__() + self.Dec_16x16 = DecBlockFirstV2( + 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels + ) + for res in range(5, res_log2 + 1): + setattr( + self, + "Dec_%dx%d" % (2 ** res, 2 ** res), + DecBlock( + res, + nf(res - 1), + nf(res), + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ), + ) + self.res_log2 = res_log2 + + def forward(self, x, ws, gs, E_features, noise_mode="random"): + x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) + for res in range(5, self.res_log2 + 1): + block = getattr(self, "Dec_%dx%d" % (2 ** res, 2 ** res)) + x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) + + return img + + +class DecStyleBlock(nn.Module): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): + super().__init__() + self.res = res + + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2 ** res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) + + def forward(self, x, img, style, skip, noise_mode="random"): + x = self.conv0(x, style, noise_mode=noise_mode) + x = x + skip + x = self.conv1(x, style, noise_mode=noise_mode) + img = self.toRGB(x, style, skip=img) + + return x, img + + +class FirstStage(nn.Module): + def __init__( + self, + img_channels, + img_resolution=256, + dim=180, + w_dim=512, + use_noise=False, + demodulate=True, + activation="lrelu", + ): + super().__init__() + res = 64 + + self.conv_first = Conv2dLayerPartial( + in_channels=img_channels + 1, + out_channels=dim, + kernel_size=3, + activation=activation, + ) + self.enc_conv = nn.ModuleList() + down_time = int(np.log2(img_resolution // res)) + # 根据图片尺寸构建 swim transformer 的层数 + for i in range(down_time): # from input size to 64 + self.enc_conv.append( + Conv2dLayerPartial( + in_channels=dim, + out_channels=dim, + kernel_size=3, + down=2, + activation=activation, + ) + ) + + # from 64 -> 16 -> 64 + depths = [2, 3, 4, 3, 2] + ratios = [1, 1 / 2, 1 / 2, 2, 2] + num_heads = 6 + window_sizes = [8, 16, 16, 16, 8] + drop_path_rate = 0.1 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.tran = nn.ModuleList() + for i, depth in enumerate(depths): + res = int(res * ratios[i]) + if ratios[i] < 1: + merge = PatchMerging(dim, dim, down=int(1 / ratios[i])) + elif ratios[i] > 1: + merge = PatchUpsampling(dim, dim, up=ratios[i]) + else: + merge = None + self.tran.append( + BasicLayer( + dim=dim, + input_resolution=[res, res], + depth=depth, + num_heads=num_heads, + window_size=window_sizes[i], + drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], + downsample=merge, + ) + ) + + # global style + down_conv = [] + for i in range(int(np.log2(16))): + down_conv.append( + Conv2dLayer( + in_channels=dim, + out_channels=dim, + kernel_size=3, + down=2, + activation=activation, + ) + ) + down_conv.append(nn.AdaptiveAvgPool2d((1, 1))) + self.down_conv = nn.Sequential(*down_conv) + self.to_style = FullyConnectedLayer( + in_features=dim, out_features=dim * 2, activation=activation + ) + self.ws_style = FullyConnectedLayer( + in_features=w_dim, out_features=dim, activation=activation + ) + self.to_square = FullyConnectedLayer( + in_features=dim, out_features=16 * 16, activation=activation + ) + + style_dim = dim * 3 + self.dec_conv = nn.ModuleList() + for i in range(down_time): # from 64 to input size + res = res * 2 + self.dec_conv.append( + DecStyleBlock( + res, + dim, + dim, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ) + ) + + def forward(self, images_in, masks_in, ws, noise_mode="random"): + x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1) + + skips = [] + x, mask = self.conv_first(x, masks_in) # input size + skips.append(x) + for i, block in enumerate(self.enc_conv): # input size to 64 + x, mask = block(x, mask) + if i != len(self.enc_conv) - 1: + skips.append(x) + + x_size = x.size()[-2:] + x = feature2token(x) + mask = feature2token(mask) + mid = len(self.tran) // 2 + for i, block in enumerate(self.tran): # 64 to 16 + if i < mid: + x, x_size, mask = block(x, x_size, mask) + skips.append(x) + elif i > mid: + x, x_size, mask = block(x, x_size, None) + x = x + skips[mid - i] + else: + x, x_size, mask = block(x, x_size, None) + + mul_map = torch.ones_like(x) * 0.5 + mul_map = F.dropout(mul_map, training=True) + ws = self.ws_style(ws[:, -1]) + add_n = self.to_square(ws).unsqueeze(1) + add_n = ( + F.interpolate( + add_n, size=x.size(1), mode="linear", align_corners=False + ) + .squeeze(1) + .unsqueeze(-1) + ) + x = x * mul_map + add_n * (1 - mul_map) + gs = self.to_style( + self.down_conv(token2feature(x, x_size)).flatten(start_dim=1) + ) + style = torch.cat([gs, ws], dim=1) + + x = token2feature(x, x_size).contiguous() + img = None + for i, block in enumerate(self.dec_conv): + x, img = block( + x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode + ) + + # ensemble + img = img * (1 - masks_in) + images_in * masks_in + + return img + + +class SynthesisNet(nn.Module): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels=3, # Number of color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_decay=1.0, + channel_max=512, # Maximum number of channels in any layer. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + drop_rate=0.5, + use_noise=False, + demodulate=True, + ): + super().__init__() + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + + self.num_layers = resolution_log2 * 2 - 3 * 2 + self.img_resolution = img_resolution + self.resolution_log2 = resolution_log2 + + # first stage + self.first_stage = FirstStage( + img_channels, + img_resolution=img_resolution, + w_dim=w_dim, + use_noise=False, + demodulate=demodulate, + ) + + # second stage + self.enc = Encoder( + resolution_log2, img_channels, activation, patch_size=5, channels=16 + ) + self.to_square = FullyConnectedLayer( + in_features=w_dim, out_features=16 * 16, activation=activation + ) + self.to_style = ToStyle( + in_channels=nf(4), + out_channels=nf(2) * 2, + activation=activation, + drop_rate=drop_rate, + ) + style_dim = w_dim + nf(2) * 2 + self.dec = Decoder( + resolution_log2, activation, style_dim, use_noise, demodulate, img_channels + ) + + def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False): + out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode) + + # encoder + x = images_in * masks_in + out_stg1 * (1 - masks_in) + x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1) + E_features = self.enc(x) + + fea_16 = E_features[4] + mul_map = torch.ones_like(fea_16) * 0.5 + mul_map = F.dropout(mul_map, training=True) + add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1) + add_n = F.interpolate( + add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False + ) + fea_16 = fea_16 * mul_map + add_n * (1 - mul_map) + E_features[4] = fea_16 + + # style + gs = self.to_style(fea_16) + + # decoder + img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode) + + # ensemble + img = img * (1 - masks_in) + images_in * masks_in + + if not return_stg1: + return img + else: + return img, out_stg1 + + +class Generator(nn.Module): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # resolution of generated image + img_channels, # Number of input color channels. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + self.synthesis = SynthesisNet( + w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs, + ) + self.mapping = MappingNet( + z_dim=z_dim, + c_dim=c_dim, + w_dim=w_dim, + num_ws=self.synthesis.num_layers, + **mapping_kwargs, + ) + + def forward( + self, + images_in, + masks_in, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + skip_w_avg_update=False, + noise_mode="none", + return_stg1=False, + ): + ws = self.mapping( + z, + c, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + skip_w_avg_update=skip_w_avg_update, + ) + img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode) + return img + + +class Discriminator(torch.nn.Module): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation="lrelu", + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + resolution_log2 = int(np.log2(img_resolution)) + assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + self.resolution_log2 = resolution_log2 + + if cmap_dim == None: + cmap_dim = nf(2) + if c_dim == 0: + cmap_dim = 0 + self.cmap_dim = cmap_dim + + if c_dim > 0: + self.mapping = MappingNet( + z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None + ) + + Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] + for res in range(resolution_log2, 2, -1): + Dis.append(DisBlock(nf(res), nf(res - 1), activation)) + + if mbstd_num_channels > 0: + Dis.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis.append( + Conv2dLayer( + nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation + ) + ) + self.Dis = nn.Sequential(*Dis) + + self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) + + # for 64x64 + Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)] + for res in range(resolution_log2, 2, -1): + Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation)) + + if mbstd_num_channels > 0: + Dis_stg1.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis_stg1.append( + Conv2dLayer( + nf(2) // 2 + mbstd_num_channels, + nf(2) // 2, + kernel_size=3, + activation=activation, + ) + ) + self.Dis_stg1 = nn.Sequential(*Dis_stg1) + + self.fc0_stg1 = FullyConnectedLayer( + nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation + ) + self.fc1_stg1 = FullyConnectedLayer( + nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim + ) + + def forward(self, images_in, masks_in, images_stg1, c): + x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1)) + x = self.fc1(self.fc0(x.flatten(start_dim=1))) + + x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1)) + x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1))) + + if self.c_dim > 0: + cmap = self.mapping(None, c) + + if self.cmap_dim > 0: + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * ( + 1 / np.sqrt(self.cmap_dim) + ) + + return x, x_stg1 + + +MAT_MODEL_URL = os.environ.get( + "MAT_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth", +) + +MAT_MODEL_MD5 = os.environ.get("MAT_MODEL_MD5", "8ca927835fa3f5e21d65ffcb165377ed") + + +class MAT(InpaintModel): + name = "mat" + min_size = 512 + pad_mod = 512 + pad_to_square = True + + def init_model(self, device, **kwargs): + seed = 240 # pick up a random number + set_seed(seed) + + fp16 = not kwargs.get("no_half", False) + use_gpu = "cuda" in str(device) and torch.cuda.is_available() + self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + G = Generator( + z_dim=512, + c_dim=0, + w_dim=512, + img_resolution=512, + img_channels=3, + mapping_kwargs={"torch_dtype": self.torch_dtype}, + ).to(self.torch_dtype) + # fmt: off + self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5) + self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device) + self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype) + # fmt: on + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL)) + + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] mask area == 255 + return: BGR IMAGE + """ + + image = norm_img(image) # [0, 1] + image = image * 2 - 1 # [0, 1] -> [-1, 1] + + mask = (mask > 127) * 255 + mask = 255 - mask + mask = norm_img(mask) + + image = ( + torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device) + ) + mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device) + + output = self.model( + image, mask, self.z, self.label, truncation_psi=1, noise_mode="none" + ) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) + output = output[0].cpu().numpy() + cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return cur_res diff --git a/lama_cleaner/model/opencv2.py b/lama_cleaner/model/opencv2.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9bce2cce248890e89e4754e4e137febbf333ec --- /dev/null +++ b/lama_cleaner/model/opencv2.py @@ -0,0 +1,28 @@ +import cv2 +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA} + + +class OpenCV2(InpaintModel): + name = "cv2" + pad_mod = 1 + + @staticmethod + def is_downloaded() -> bool: + return True + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + cur_res = cv2.inpaint( + image[:, :, ::-1], + mask, + inpaintRadius=config.cv2_radius, + flags=flag_map[config.cv2_flag], + ) + return cur_res diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..0004341864dd0cdf99bcc43cec996c0023496a37 --- /dev/null +++ b/lama_cleaner/model/paint_by_example.py @@ -0,0 +1,79 @@ +import PIL +import PIL.Image +import cv2 +import torch +from diffusers import DiffusionPipeline +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import set_seed +from lama_cleaner.schema import Config + + +class PaintByExample(DiffusionInpaintModel): + name = "paint_by_example" + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + fp16 = not kwargs.get('no_half', False) + use_gpu = device == torch.device('cuda') and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} + + if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + logger.info("Disable Paint By Example Model NSFW checker") + model_kwargs.update(dict( + safety_checker=None, + requires_safety_checker=False + )) + + self.model = DiffusionPipeline.from_pretrained( + "Fantasy-Studio/Paint-by-Example", + torch_dtype=torch_dtype, + **model_kwargs + ) + + self.model.enable_attention_slicing() + if kwargs.get('enable_xformers', False): + self.model.enable_xformers_memory_efficient_attention() + + # TODO: gpu_id + if kwargs.get('cpu_offload', False) and use_gpu: + self.model.image_encoder = self.model.image_encoder.to(device) + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + output = self.model( + image=PIL.Image.fromarray(image), + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + example_image=config.paint_by_example_example_image, + num_inference_steps=config.paint_by_example_steps, + output_type='np.array', + generator=torch.manual_seed(config.paint_by_example_seed) + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def forward_post_process(self, result, image, mask, config): + if config.paint_by_example_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.paint_by_example_mask_blur != 0: + k = 2 * config.paint_by_example_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model/pipeline/__init__.py b/lama_cleaner/model/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c38f1fc9646fc343e71db484737a9bdd6ea76c7 --- /dev/null +++ b/lama_cleaner/model/pipeline/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_stable_diffusion_controlnet_inpaint import ( + StableDiffusionControlNetInpaintPipeline, +) diff --git a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..9d22592fb80880960b078ca98b4725c560f3d4ca --- /dev/null +++ b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py @@ -0,0 +1,585 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py + +import torch +import PIL.Image +import numpy as np + +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import * + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + >>> # download an image + >>> image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> image = np.array(image) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) + >>> mask_image = np.array(mask_image) + >>> # get canny image + >>> canny_image = cv2.Canny(image, 100, 200) + >>> canny_image = canny_image[:, :, None] + >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) + >>> canny_image = Image.fromarray(canny_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking doggo", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... mask_image=mask_image + ... ).images[0] + ``` +""" + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError( + f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" + ) + + # Batch single image + if image.ndim == 3: + assert ( + image.shape[0] == 3 + ), "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert ( + image.ndim == 4 and mask.ndim == 4 + ), "Image and Mask must have 4 dimensions" + assert ( + image.shape[-2:] == mask.shape[-2:] + ), "Image and Mask must have the same spatial dimensions" + assert ( + image.shape[0] == mask.shape[0] + ), "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError( + f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" + ) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = np.concatenate( + [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 + ) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`]): + Provides additional conditioning to the unet during the denoising process + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample( + generator=generator[i] + ) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample( + generator=generator + ) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) + if do_classifier_free_guidance + else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + List[torch.FloatTensor], + List[PIL.Image.Image], + ] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: float = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can + also be accepted as an image. The control image is automatically resized to fit the output image. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. + Examples: + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, control_image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare image + control_image = self.prepare_image( + control_image, + width, + height, + batch_size * num_images_per_prompt, + num_images_per_prompt, + device, + self.controlnet.dtype, + ) + + if do_classifier_free_guidance: + control_image = torch.cat([control_image] * 2) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.controlnet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # EXTRA: prepare mask latents + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=control_image, + return_dict=False, + ) + + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample *= controlnet_conditioning_scale + + # predict the noise residual + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/lama_cleaner/model/plms_sampler.py b/lama_cleaner/model/plms_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5d7668e45c0fe12c580059b83255722c6c8576 --- /dev/null +++ b/lama_cleaner/model/plms_sampler.py @@ -0,0 +1,225 @@ +# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +import torch +import numpy as np +from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like +from tqdm import tqdm + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + steps, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + return img + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py new file mode 100644 index 0000000000000000000000000000000000000000..ab372e6c3794b0553a206b55a63174e464072100 --- /dev/null +++ b/lama_cleaner/model/sd.py @@ -0,0 +1,193 @@ +import gc + +import PIL.Image +import cv2 +import numpy as np +import torch +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import torch_gc, get_scheduler +from lama_cleaner.schema import Config + + +class CPUTextEncoderWrapper: + def __init__(self, text_encoder, torch_dtype): + self.config = text_encoder.config + self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) + self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) + self.torch_dtype = torch_dtype + del text_encoder + torch_gc() + + def __call__(self, x, **kwargs): + input_device = x.device + return [ + self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] + .to(input_device) + .to(self.torch_dtype) + ] + + @property + def dtype(self): + return self.torch_dtype + + +def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True): + from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + download_from_original_stable_diffusion_ckpt, + ) + from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline + + logger.info(f"Converting {local_model_path} to diffusers pipeline") + + pipe = download_from_original_stable_diffusion_ckpt( + local_model_path, + num_in_channels=9, + from_safetensors=local_model_path.endswith("safetensors"), + device="cpu", + ) + + inpaint_pipe = StableDiffusionInpaintPipeline( + vae=pipe.vae, + text_encoder=pipe.text_encoder, + tokenizer=pipe.tokenizer, + unet=pipe.unet, + scheduler=pipe.scheduler, + safety_checker=None if disable_nsfw else pipe.safety_checker, + feature_extractor=None if disable_nsfw else pipe.safety_checker, + requires_safety_checker=not disable_nsfw, + ) + + del pipe + gc.collect() + return inpaint_pipe.to(torch_dtype=torch_dtype) + + +class SD(DiffusionInpaintModel): + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline + + fp16 = not kwargs.get("no_half", False) + + model_kwargs = { + "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) + } + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + if kwargs.get("sd_local_model_path", None): + self.model = load_from_local_model( + kwargs["sd_local_model_path"], + torch_dtype=torch_dtype, + ) + else: + self.model = StableDiffusionInpaintPipeline.from_pretrained( + self.model_id_or_path, + revision="fp16" if use_gpu and fp16 else "main", + torch_dtype=torch_dtype, + use_auth_token=kwargs["hf_access_token"], + **model_kwargs, + ) + + # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing + self.model.enable_attention_slicing() + # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention + if kwargs.get("enable_xformers", False): + self.model.enable_xformers_memory_efficient_attention() + + if kwargs.get("cpu_offload", False) and use_gpu: + # TODO: gpu_id + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + + scheduler_config = self.model.scheduler.config + scheduler = get_scheduler(config.sd_sampler, scheduler_config) + self.model.scheduler = scheduler + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] + + img_h, img_w = image.shape[:2] + + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np.array", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True + + +class SD15(SD): + name = "sd1.5" + model_id_or_path = "runwayml/stable-diffusion-inpainting" + + +class Anything4(SD): + name = "anything4" + model_id_or_path = "Sanster/anything-4.0-inpainting" + + +class RealisticVision14(SD): + name = "realisticVision1.4" + model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting" + + +class SD2(SD): + name = "sd2" + model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7552423745b786d523abab435bdc576ccbcbb3de --- /dev/null +++ b/lama_cleaner/model/utils.py @@ -0,0 +1,941 @@ +import math +import random +from typing import Any + +import torch +import numpy as np +import collections +from itertools import repeat + +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + UniPCMultistepScheduler, +) + +from lama_cleaner.schema import SDSampler +from torch import conv2d, conv_transpose2d + + +def make_beta_schedule( + device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ).to(device) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2).to(device) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=device) + + args = timesteps[:, None].float() * freqs[None] + + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +###### MAT and FcF ####### + + +def normalize_2nd_moment(x, dim=1): + return ( + x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt() + ) + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops.""" + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + + +def bias_act( + x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref" +): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ["ref", "cuda"] + return _bias_act_ref( + x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp + ) + + +def _get_filter_size(f): + if f is None: + return 1, 1 + + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + + fw = int(fw) + fh = int(fh) + assert fw >= 1 and fh >= 1 + return fw, fh + + +def _get_weight_shape(w): + shape = [int(sz) for sz in w.shape] + return shape + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def setup_filter( + f, + device=torch.device("cpu"), + normalize=True, + flip_filter=False, + gain=1, + separable=None, +): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = f.ndim == 1 and f.numel() >= 8 + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + +activation_funcs = { + "linear": EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + cuda_idx=1, + ref="", + has_2nd_grad=False, + ), + "relu": EasyDict( + func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=2, + ref="y", + has_2nd_grad=False, + ), + "lrelu": EasyDict( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, + def_gain=np.sqrt(2), + cuda_idx=3, + ref="y", + has_2nd_grad=False, + ), + "tanh": EasyDict( + func=lambda x, **_: torch.tanh(x), + def_alpha=0, + def_gain=1, + cuda_idx=4, + ref="y", + has_2nd_grad=True, + ), + "sigmoid": EasyDict( + func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, + def_gain=1, + cuda_idx=5, + ref="y", + has_2nd_grad=True, + ), + "elu": EasyDict( + func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, + def_gain=1, + cuda_idx=6, + ref="y", + has_2nd_grad=True, + ), + "selu": EasyDict( + func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, + def_gain=1, + cuda_idx=7, + ref="y", + has_2nd_grad=True, + ), + "softplus": EasyDict( + func=lambda x, **_: torch.nn.functional.softplus(x), + def_alpha=0, + def_gain=1, + cuda_idx=8, + ref="y", + has_2nd_grad=True, + ), + "swish": EasyDict( + func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=9, + ref="x", + has_2nd_grad=True, + ), +} + + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # assert isinstance(x, torch.Tensor) + # assert impl in ['ref', 'cuda'] + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) + + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + # upx, upy = _parse_scaling(up) + # downx, downy = _parse_scaling(down) + + upx, upy = up, up + downx, downy = down, down + + # padx0, padx1, pady0, pady1 = _parse_padding(padding) + padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad( + x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] + ) + x = x[ + :, + :, + max(-pady0, 0) : x.shape[2] - max(-pady1, 0), + max(-padx0, 0) : x.shape[3] - max(-padx1, 0), + ] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + # padx0, padx1, pady0, pady1 = _parse_padding(padding) + padx0, padx1, pady0, pady1 = padding, padding, padding, padding + + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d( + x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl + ) + + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + # upx, upy = up, up + padx0, padx1, pady0, pady1 = _parse_padding(padding) + # padx0, padx1, pady0, pady1 = padding, padding, padding, padding + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d( + x, + f, + up=up, + padding=p, + flip_filter=flip_filter, + gain=gain * upx * upy, + impl=impl, + ) + + +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + G = ( + torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) + if self.group_size is not None + else N + ) + F = self.num_channels + c = C // F + + y = x.reshape( + G, -1, F, c, H, W + ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + + +class FullyConnectedLayer(torch.nn.Module): + def __init__( + self, + in_features, # Number of input features. + out_features, # Number of output features. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=1, # Learning rate multiplier. + bias_init=0, # Initial value for the additive bias. + ): + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn([out_features, in_features]) / lr_multiplier + ) + self.bias = ( + torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) + if bias + else None + ) + self.activation = activation + + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight * self.weight_gain + b = self.bias + if b is not None and self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == "linear" and b is not None: + # out = torch.addmm(b.unsqueeze(0), x, w.t()) + x = x.matmul(w.t()) + out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]) + else: + x = x.matmul(w.t()) + out = bias_act(x, b, act=self.activation, dim=x.ndim - 1) + return out + + +def _conv2d_wrapper( + x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True +): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.""" + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if ( + not flip_weight + ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if ( + kw == 1 + and kh == 1 + and stride == 1 + and padding in [0, [0, 0], (0, 0)] + and not transpose + ): + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape( + [in_shape[0], in_channels_per_group, -1] + ) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv_transpose2d if transpose else conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + + +def conv2d_resample( + x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False +): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + # px0, px1, py0, py1 = _parse_padding(padding) + px0, px1, py0, py1 = padding, padding, padding, padding + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d( + x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter + ) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d( + x=x, + f=f, + up=up, + padding=[px0, px1, py0, py1], + gain=up ** 2, + flip_filter=flip_filter, + ) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = _conv2d_wrapper( + x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight + ) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape( + groups * in_channels_per_group, out_channels // groups, kh, kw + ) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper( + x=x, + w=w, + stride=up, + padding=[pyt, pxt], + groups=groups, + transpose=True, + flip_weight=(not flip_weight), + ) + x = upfirdn2d( + x=x, + f=f, + padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], + gain=up ** 2, + flip_filter=flip_filter, + ) + if down > 1: + x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper( + x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight + ) + + # Fallback: Generic reference implementation. + x = upfirdn2d( + x=x, + f=(f if up > 1 else None), + up=up, + padding=[px0, px1, py0, py1], + gain=up ** 2, + flip_filter=flip_filter, + ) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + +class Conv2dLayer(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + channels_last=False, # Expect the input to have memory_format=channels_last? + trainable=True, # Update the weights of this layer during training? + ): + super().__init__() + self.activation = activation + self.up = up + self.down = down + self.register_buffer("resample_filter", setup_filter(resample_filter)) + self.conv_clamp = conv_clamp + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = activation_funcs[activation].def_gain + + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer("weight", weight) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + x = conv2d_resample( + x=x, + w=w, + f=self.resample_filter, + up=self.up, + down=self.down, + padding=self.padding, + ) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + out = bias_act( + x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp + ) + return out + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_scheduler(sd_sampler, scheduler_config): + if sd_sampler == SDSampler.ddim: + return DDIMScheduler.from_config(scheduler_config) + elif sd_sampler == SDSampler.pndm: + return PNDMScheduler.from_config(scheduler_config) + elif sd_sampler == SDSampler.k_lms: + return LMSDiscreteScheduler.from_config(scheduler_config) + elif sd_sampler == SDSampler.k_euler: + return EulerDiscreteScheduler.from_config(scheduler_config) + elif sd_sampler == SDSampler.k_euler_a: + return EulerAncestralDiscreteScheduler.from_config(scheduler_config) + elif sd_sampler == SDSampler.dpm_plus_plus: + return DPMSolverMultistepScheduler.from_config(scheduler_config) + elif sd_sampler == SDSampler.uni_pc: + return UniPCMultistepScheduler.from_config(scheduler_config) + else: + raise ValueError(sd_sampler) diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py new file mode 100644 index 0000000000000000000000000000000000000000..340f428129dc3d99f30d32505efeaf70c7131fe9 --- /dev/null +++ b/lama_cleaner/model/zits.py @@ -0,0 +1,447 @@ +import os +import time + +import cv2 +import torch +import torch.nn.functional as F + +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.schema import Config +import numpy as np + +from lama_cleaner.model.base import InpaintModel + +ZITS_INPAINT_MODEL_URL = os.environ.get( + "ZITS_INPAINT_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt", +) +ZITS_INPAINT_MODEL_MD5 = os.environ.get( + "ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154" +) + +ZITS_EDGE_LINE_MODEL_URL = os.environ.get( + "ZITS_EDGE_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt", +) +ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get( + "ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f" +) + +ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get( + "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt", +) +ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get( + "ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927" +) + +ZITS_WIRE_FRAME_MODEL_URL = os.environ.get( + "ZITS_WIRE_FRAME_MODEL_URL", + "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt", +) +ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get( + "ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b" +) + + +def resize(img, height, width, center_crop=False): + imgh, imgw = img.shape[0:2] + + if center_crop and imgh != imgw: + # center crop + side = np.minimum(imgh, imgw) + j = (imgh - side) // 2 + i = (imgw - side) // 2 + img = img[j : j + side, i : i + side, ...] + + if imgh > height and imgw > width: + inter = cv2.INTER_AREA + else: + inter = cv2.INTER_LINEAR + img = cv2.resize(img, (height, width), interpolation=inter) + + return img + + +def to_tensor(img, scale=True, norm=False): + if img.ndim == 2: + img = img[:, :, np.newaxis] + c = img.shape[-1] + + if scale: + img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255) + else: + img_t = torch.from_numpy(img).permute(2, 0, 1).float() + + if norm: + mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) + std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) + img_t = (img_t - mean) / std + return img_t + + +def load_masked_position_encoding(mask): + ones_filter = np.ones((3, 3), dtype=np.float32) + d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32) + d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32) + d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32) + d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32) + str_size = 256 + pos_num = 128 + + ori_mask = mask.copy() + ori_h, ori_w = ori_mask.shape[0:2] + ori_mask = ori_mask / 255 + mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA) + mask[mask > 0] = 255 + h, w = mask.shape[0:2] + mask3 = mask.copy() + mask3 = 1.0 - (mask3 / 255.0) + pos = np.zeros((h, w), dtype=np.int32) + direct = np.zeros((h, w, 4), dtype=np.int32) + i = 0 + while np.sum(1 - mask3) > 0: + i += 1 + mask3_ = cv2.filter2D(mask3, -1, ones_filter) + mask3_[mask3_ > 0] = 1 + sub_mask = mask3_ - mask3 + pos[sub_mask == 1] = i + + m = cv2.filter2D(mask3, -1, d_filter1) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 0] = 1 + + m = cv2.filter2D(mask3, -1, d_filter2) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 1] = 1 + + m = cv2.filter2D(mask3, -1, d_filter3) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 2] = 1 + + m = cv2.filter2D(mask3, -1, d_filter4) + m[m > 0] = 1 + m = m - mask3 + direct[m == 1, 3] = 1 + + mask3 = mask3_ + + abs_pos = pos.copy() + rel_pos = pos / (str_size / 2) # to 0~1 maybe larger than 1 + rel_pos = (rel_pos * pos_num).astype(np.int32) + rel_pos = np.clip(rel_pos, 0, pos_num - 1) + + if ori_w != w or ori_h != h: + rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) + rel_pos[ori_mask == 0] = 0 + direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) + direct[ori_mask == 0, :] = 0 + + return rel_pos, abs_pos, direct + + +def load_image(img, mask, device, sigma256=3.0): + """ + Args: + img: [H, W, C] RGB + mask: [H, W] 255 为 masks 区域 + sigma256: + + Returns: + + """ + h, w, _ = img.shape + imgh, imgw = img.shape[0:2] + img_256 = resize(img, 256, 256) + + mask = (mask > 127).astype(np.uint8) * 255 + mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA) + mask_256[mask_256 > 0] = 255 + + mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA) + mask_512[mask_512 > 0] = 255 + + # original skimage implemention + # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny + # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max. + # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max. + + try: + import skimage + gray_256 = skimage.color.rgb2gray(img_256) + edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float) + # cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8)) + # cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8)) + except: + gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY) + gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256) + edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2)) + + # cv2.imwrite("opencv_edge.jpg", edge_256) + + # line + img_512 = resize(img, 512, 512) + + rel_pos, abs_pos, direct = load_masked_position_encoding(mask) + + batch = dict() + batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device) + batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device) + batch["masks"] = to_tensor(mask).unsqueeze(0).to(device) + batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device) + batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device) + batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device) + batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device) + batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device) + batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device) + batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device) + batch["h"] = imgh + batch["w"] = imgw + + return batch + + +def to_device(data, device): + if isinstance(data, torch.Tensor): + return data.to(device) + if isinstance(data, dict): + for key in data: + if isinstance(data[key], torch.Tensor): + data[key] = data[key].to(device) + return data + if isinstance(data, list): + return [to_device(d, device) for d in data] + + +class ZITS(InpaintModel): + name = "zits" + min_size = 256 + pad_mod = 32 + pad_to_square = True + + def __init__(self, device, **kwargs): + """ + + Args: + device: + """ + super().__init__(device) + self.device = device + self.sample_edge_line_iterations = 1 + + def init_model(self, device, **kwargs): + self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5) + self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5) + self.structure_upsample = load_jit_model( + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 + ) + self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL), + get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL), + get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL), + get_cache_path_by_url(ZITS_INPAINT_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def wireframe_edge_and_line(self, items, enable: bool): + # 最终向 items 中添加 edge 和 line key + if not enable: + items["edge"] = torch.zeros_like(items["masks"]) + items["line"] = torch.zeros_like(items["masks"]) + return + + start = time.time() + try: + line_256 = self.wireframe_forward( + items["img_512"], + h=256, + w=256, + masks=items["mask_512"], + mask_th=0.85, + ) + except: + line_256 = torch.zeros_like(items["mask_256"]) + + print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms") + + # np_line = (line[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line.jpg", np_line) + + start = time.time() + edge_pred, line_pred = self.sample_edge_line_logits( + context=[items["img_256"], items["edge_256"], line_256], + mask=items["mask_256"].clone(), + iterations=self.sample_edge_line_iterations, + add_v=0.05, + mul_v=4, + ) + print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms") + + # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("edge_pred.jpg", np_edge_pred) + # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line_pred.jpg", np_line_pred) + # exit() + + input_size = min(items["h"], items["w"]) + if input_size != 256 and input_size > 256: + while edge_pred.shape[2] < input_size: + edge_pred = self.structure_upsample(edge_pred) + edge_pred = torch.sigmoid((edge_pred + 2) * 2) + + line_pred = self.structure_upsample(line_pred) + line_pred = torch.sigmoid((line_pred + 2) * 2) + + edge_pred = F.interpolate( + edge_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + line_pred = F.interpolate( + line_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + + # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred) + # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) + # cv2.imwrite("line_pred_upsample.jpg", np_line_pred) + # exit() + + items["edge"] = edge_pred.detach() + items["line"] = line_pred.detach() + + @torch.no_grad() + def forward(self, image, mask, config: Config): + """Input images and output images have same size + images: [H, W, C] RGB + masks: [H, W] + return: BGR IMAGE + """ + mask = mask[:, :, 0] + items = load_image(image, mask, device=self.device) + + self.wireframe_edge_and_line(items, config.zits_wireframe) + + inpainted_image = self.inpaint( + items["images"], + items["masks"], + items["edge"], + items["line"], + items["rel_pos"], + items["direct"], + ) + + inpainted_image = inpainted_image * 255.0 + inpainted_image = ( + inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) + ) + inpainted_image = inpainted_image[:, :, ::-1] + + # cv2.imwrite("inpainted.jpg", inpainted_image) + # exit() + + return inpainted_image + + def wireframe_forward(self, images, h, w, masks, mask_th=0.925): + lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1) + lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1) + images = images * 255.0 + # the masks value of lcnn is 127.5 + masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5 + masked_images = (masked_images - lcnn_mean) / lcnn_std + + def to_int(x): + return tuple(map(int, x)) + + lines_tensor = [] + lmap = np.zeros((h, w)) + + output_masked = self.wireframe(masked_images) + + output_masked = to_device(output_masked, "cpu") + if output_masked["num_proposals"] == 0: + lines_masked = [] + scores_masked = [] + else: + lines_masked = output_masked["lines_pred"].numpy() + lines_masked = [ + [line[1] * h, line[0] * w, line[3] * h, line[2] * w] + for line in lines_masked + ] + scores_masked = output_masked["lines_score"].numpy() + + for line, score in zip(lines_masked, scores_masked): + if score > mask_th: + try: + import skimage + rr, cc, value = skimage.draw.line_aa( + *to_int(line[0:2]), *to_int(line[2:4]) + ) + lmap[rr, cc] = np.maximum(lmap[rr, cc], value) + except: + cv2.line(lmap, to_int(line[0:2][::-1]), to_int(line[2:4][::-1]), (1, 1, 1), 1, cv2.LINE_AA) + + lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8) + lines_tensor.append(to_tensor(lmap).unsqueeze(0)) + + lines_tensor = torch.cat(lines_tensor, dim=0) + return lines_tensor.detach().to(self.device) + + def sample_edge_line_logits( + self, context, mask=None, iterations=1, add_v=0, mul_v=4 + ): + [img, edge, line] = context + + img = img * (1 - mask) + edge = edge * (1 - mask) + line = line * (1 - mask) + + for i in range(iterations): + edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask) + + edge_pred = torch.sigmoid(edge_logits) + line_pred = torch.sigmoid((line_logits + add_v) * mul_v) + edge = edge + edge_pred * mask + edge[edge >= 0.25] = 1 + edge[edge < 0.25] = 0 + line = line + line_pred * mask + + b, _, h, w = edge_pred.shape + edge_pred = edge_pred.reshape(b, -1, 1) + line_pred = line_pred.reshape(b, -1, 1) + mask = mask.reshape(b, -1) + + edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1) + line_probs = torch.cat([1 - line_pred, line_pred], dim=-1) + edge_probs[:, :, 1] += 0.5 + line_probs[:, :, 1] += 0.5 + edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100) + line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100) + + indices = torch.sort( + edge_max_probs + line_max_probs, dim=-1, descending=True + )[1] + + for ii in range(b): + keep = int((i + 1) / iterations * torch.sum(mask[ii, ...])) + + assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!" + mask[ii][indices[ii, :keep]] = 0 + + mask = mask.reshape(b, 1, h, w) + edge = edge * (1 - mask) + line = line * (1 - mask) + + edge, line = edge.to(torch.float32), line.to(torch.float32) + return edge, line diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..34c7b4e69170d1d3ccf4f157349e5fca4c1bd8e7 --- /dev/null +++ b/lama_cleaner/model_manager.py @@ -0,0 +1,118 @@ +import torch +import gc + +from loguru import logger + +from lama_cleaner.const import SD15_MODELS +from lama_cleaner.helper import switch_mps_device +from lama_cleaner.model.controlnet import ControlNet +from lama_cleaner.model.fcf import FcF +from lama_cleaner.model.lama import LaMa +from lama_cleaner.model.ldm import LDM +from lama_cleaner.model.manga import Manga +from lama_cleaner.model.mat import MAT +from lama_cleaner.model.paint_by_example import PaintByExample +from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix +from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14 +from lama_cleaner.model.utils import torch_gc +from lama_cleaner.model.zits import ZITS +from lama_cleaner.model.opencv2 import OpenCV2 +from lama_cleaner.schema import Config + +models = { + "lama": LaMa, + "ldm": LDM, + "zits": ZITS, + "mat": MAT, + "fcf": FcF, + SD15.name: SD15, + Anything4.name: Anything4, + RealisticVision14.name: RealisticVision14, + "cv2": OpenCV2, + "manga": Manga, + "sd2": SD2, + "paint_by_example": PaintByExample, + "instruct_pix2pix": InstructPix2Pix, +} + + +class ModelManager: + def __init__(self, name: str, device: torch.device, **kwargs): + self.name = name + self.device = device + self.kwargs = kwargs + self.model = self.init_model(name, device, **kwargs) + + def init_model(self, name: str, device, **kwargs): + if name in SD15_MODELS and kwargs.get("sd_controlnet", False): + return ControlNet(device, **{**kwargs, "name": name}) + + if name in models: + model = models[name](device, **kwargs) + else: + raise NotImplementedError(f"Not supported model: {name}") + return model + + def is_downloaded(self, name: str) -> bool: + if name in models: + return models[name].is_downloaded() + else: + raise NotImplementedError(f"Not supported model: {name}") + + def __call__(self, image, mask, config: Config): + self.switch_controlnet_method(control_method=config.controlnet_method) + return self.model(image, mask, config) + + def switch(self, new_name: str, **kwargs): + if new_name == self.name: + return + try: + if torch.cuda.memory_allocated() > 0: + # Clear current loaded model from memory + torch.cuda.empty_cache() + del self.model + gc.collect() + + self.model = self.init_model( + new_name, switch_mps_device(new_name, self.device), **self.kwargs + ) + self.name = new_name + except NotImplementedError as e: + raise e + + def switch_controlnet_method(self, control_method: str): + if not self.kwargs.get("sd_controlnet"): + return + if self.kwargs["sd_controlnet_method"] == control_method: + return + if not hasattr(self.model, "is_local_sd_model"): + return + + if self.model.is_local_sd_model: + # is_native_control_inpaint 表示加载了普通 SD 模型 + if ( + self.model.is_native_control_inpaint + and control_method != "control_v11p_sd15_inpaint" + ): + raise RuntimeError( + f"--sd-local-model-path load a normal SD model, " + f"to use {control_method} you should load an inpainting SD model" + ) + elif ( + not self.model.is_native_control_inpaint + and control_method == "control_v11p_sd15_inpaint" + ): + raise RuntimeError( + f"--sd-local-model-path load an inpainting SD model, " + f"to use {control_method} you should load a norml SD model" + ) + + del self.model + torch_gc() + + old_method = self.kwargs["sd_controlnet_method"] + self.kwargs["sd_controlnet_method"] = control_method + self.model = self.init_model( + self.name, switch_mps_device(self.name, self.device), **self.kwargs + ) + logger.info(f"Switch ControlNet method from {old_method} to {control_method}") diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py new file mode 100644 index 0000000000000000000000000000000000000000..086b35b32d0b4e90846a8e29cf065c3719c29665 --- /dev/null +++ b/lama_cleaner/parse_args.py @@ -0,0 +1,256 @@ +import os +import imghdr +import argparse +from pathlib import Path + +from loguru import logger + +from lama_cleaner.const import * +from lama_cleaner.runtime import dump_environment_info + + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", default=8080, type=int) + + parser.add_argument( + "--config-installer", + action="store_true", + help="Open config web page, mainly for windows installer", + ) + parser.add_argument( + "--load-installer-config", + action="store_true", + help="Load all cmd args from installer config file", + ) + parser.add_argument( + "--installer-config", default=None, help="Config file for windows installer" + ) + + parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS) + parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) + parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) + parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) + parser.add_argument( + "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP + ) + parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) + parser.add_argument( + "--sd-controlnet-method", + default=DEFAULT_CONTROLNET_METHOD, + choices=SD_CONTROLNET_CHOICES, + ) + parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) + parser.add_argument( + "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP + ) + parser.add_argument( + "--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP + ) + parser.add_argument( + "--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES + ) + parser.add_argument("--gui", action="store_true", help=GUI_HELP) + parser.add_argument( + "--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP + ) + parser.add_argument( + "--gui-size", + default=[1600, 1000], + nargs=2, + type=int, + help="Set window size for GUI", + ) + parser.add_argument("--input", type=str, default=None, help=INPUT_HELP) + parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP) + parser.add_argument( + "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP + ) + parser.add_argument( + "--disable-model-switch", + action="store_true", + help="Disable model switch in frontend", + ) + parser.add_argument( + "--quality", + default=95, + type=int, + help=QUALITY_HELP, + ) + + # Plugins + parser.add_argument( + "--enable-interactive-seg", + action="store_true", + help=INTERACTIVE_SEG_HELP, + ) + parser.add_argument( + "--interactive-seg-model", + default="vit_l", + choices=AVAILABLE_INTERACTIVE_SEG_MODELS, + help=INTERACTIVE_SEG_MODEL_HELP, + ) + parser.add_argument( + "--interactive-seg-device", + default="cpu", + choices=AVAILABLE_INTERACTIVE_SEG_DEVICES, + ) + parser.add_argument( + "--enable-remove-bg", + action="store_true", + help=REMOVE_BG_HELP, + ) + parser.add_argument( + "--enable-anime-seg", + action="store_true", + help=ANIMESEG_HELP, + ) + parser.add_argument( + "--enable-realesrgan", + action="store_true", + help=REALESRGAN_HELP, + ) + parser.add_argument( + "--realesrgan-device", + default="cpu", + type=str, + choices=REALESRGAN_AVAILABLE_DEVICES, + ) + parser.add_argument( + "--realesrgan-model", + default=RealESRGANModelName.realesr_general_x4v3.value, + type=str, + choices=RealESRGANModelNameList, + ) + parser.add_argument( + "--realesrgan-no-half", + action="store_true", + help="Disable half precision for RealESRGAN", + ) + parser.add_argument("--enable-gfpgan", action="store_true", help=GFPGAN_HELP) + parser.add_argument( + "--gfpgan-device", default="cpu", type=str, choices=GFPGAN_AVAILABLE_DEVICES + ) + parser.add_argument( + "--enable-restoreformer", action="store_true", help=RESTOREFORMER_HELP + ) + parser.add_argument( + "--restoreformer-device", + default="cpu", + type=str, + choices=RESTOREFORMER_AVAILABLE_DEVICES, + ) + parser.add_argument( + "--enable-gif", + action="store_true", + help=GIF_HELP, + ) + parser.add_argument( + "--install-plugins-package", + action="store_true", + ) + ######### + + # useless args + parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS) + parser.add_argument( + "--sd-disable-nsfw", action="store_true", help=argparse.SUPPRESS + ) + parser.add_argument("--sd-run-local", action="store_true", help=argparse.SUPPRESS) + parser.add_argument( + "--sd-enable-xformers", action="store_true", help=argparse.SUPPRESS + ) + + args = parser.parse_args() + + # collect system info to help debug + dump_environment_info() + if args.install_plugins_package: + from lama_cleaner.installer import install_plugins_package + + install_plugins_package() + exit() + + if args.config_installer: + if args.installer_config is None: + parser.error( + "args.config_installer==True, must set args.installer_config to store config file" + ) + from lama_cleaner.web_config import main + + logger.info("Launching installer web config page") + main(args.installer_config) + exit() + + if args.load_installer_config: + if args.installer_config and not os.path.exists(args.installer_config): + parser.error(f"args.installer_config={args.installer_config} not exists") + + logger.info(f"Loading installer config from {args.installer_config}") + _args = load_config(args.installer_config) + for k, v in vars(_args).items(): + if k in vars(args): + setattr(args, k, v) + + if args.device == "cuda": + import platform + + if platform.system() == "Darwin": + logger.info("MacOS does not support cuda, use cpu instead") + setattr(args, "device", "cpu") + else: + import torch + + if torch.cuda.is_available() is False: + parser.error( + "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation" + ) + + if args.sd_local_model_path and args.model == "sd1.5": + if not os.path.exists(args.sd_local_model_path): + parser.error( + f"invalid --sd-local-model-path: {args.sd_local_model_path} not exists" + ) + if not os.path.isfile(args.sd_local_model_path): + parser.error( + f"invalid --sd-local-model-path: {args.sd_local_model_path} is a directory" + ) + + os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR + if args.model_dir and args.model_dir is not None: + if os.path.isfile(args.model_dir): + parser.error(f"invalid --model-dir: {args.model_dir} is a file") + + if not os.path.exists(args.model_dir): + logger.info(f"Create model cache directory: {args.model_dir}") + Path(args.model_dir).mkdir(exist_ok=True, parents=True) + + os.environ["XDG_CACHE_HOME"] = args.model_dir + os.environ["U2NET_HOME"] = args.model_dir + + if args.input and args.input is not None: + if not os.path.exists(args.input): + parser.error(f"invalid --input: {args.input} not exists") + if os.path.isfile(args.input): + if imghdr.what(args.input) is None: + parser.error(f"invalid --input: {args.input} is not a valid image file") + else: + if args.output_dir is None: + parser.error( + f"invalid --input: {args.input} is a directory, --output-dir is required" + ) + + if args.output_dir is not None: + output_dir = Path(args.output_dir) + if not output_dir.exists(): + logger.info(f"Creating output directory: {output_dir}") + output_dir.mkdir(parents=True) + else: + if not output_dir.is_dir(): + parser.error(f"invalid --output-dir: {output_dir} is not a directory") + + return args diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7dd8b85fd61ccb7b3241bc519037d7b3e2da721 --- /dev/null +++ b/lama_cleaner/plugins/__init__.py @@ -0,0 +1,7 @@ +from .interactive_seg import InteractiveSeg +from .remove_bg import RemoveBG +from .realesrgan import RealESRGANUpscaler +from .gfpgan_plugin import GFPGANPlugin +from .restoreformer import RestoreFormerPlugin +from .gif import MakeGIF +from .anime_seg import AnimeSeg diff --git a/lama_cleaner/plugins/anime_seg.py b/lama_cleaner/plugins/anime_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..f82ccc5ed4edd01851c8c0e912758b3267e83c73 --- /dev/null +++ b/lama_cleaner/plugins/anime_seg.py @@ -0,0 +1,455 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from PIL import Image + +from lama_cleaner.helper import load_model +from lama_cleaner.plugins.base_plugin import BasePlugin + + +class REBNCONV(nn.Module): + def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): + super(REBNCONV, self).__init__() + + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride + ) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src, tar): + src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False) + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7, self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-5 ### +class RSU5(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4 ### +class RSU4(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4F ### +class RSU4F(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + + return hx1d + hxin + + +class ISNetDIS(nn.Module): + def __init__(self, in_ch=3, out_ch=1): + super(ISNetDIS, self).__init__() + + self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) + self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage1 = RSU7(64, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + # -------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + + # side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1, x) + return d1.sigmoid() + + +# 从小到大 +ANIME_SEG_MODELS = { + "url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth", + "md5": "5f25479076b73074730ab8de9e8f2051", +} + + +class AnimeSeg(BasePlugin): + # Model from: https://github.com/SkyTNT/anime-segmentation + name = "AnimeSeg" + + def __init__(self): + super().__init__() + self.model = load_model( + ISNetDIS(), + ANIME_SEG_MODELS["url"], + "cpu", + ANIME_SEG_MODELS["md5"], + ) + + def __call__(self, rgb_np_img, files, form): + return self.forward(rgb_np_img) + + @torch.no_grad() + def forward(self, rgb_np_img): + s = 1024 + + h0, w0 = h, w = rgb_np_img.shape[0], rgb_np_img.shape[1] + if h > w: + h, w = s, int(s * w / h) + else: + h, w = int(s * h / w), s + ph, pw = s - h, s - w + tmpImg = np.zeros([s, s, 3], dtype=np.float32) + tmpImg[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = ( + cv2.resize(rgb_np_img, (w, h)) / 255 + ) + tmpImg = tmpImg.transpose((2, 0, 1)) + tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor) + mask = self.model(tmpImg) + mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] + mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0)) + mask = Image.fromarray((mask * 255).astype("uint8"), mode="L") + + empty = Image.new("RGBA", (w0, h0), 0) + img = Image.fromarray(rgb_np_img) + cutout = Image.composite(img, empty, mask) + return np.asarray(cutout) diff --git a/lama_cleaner/plugins/base_plugin.py b/lama_cleaner/plugins/base_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc3710c2a219e3a224bf6fcec6617e6db2756da --- /dev/null +++ b/lama_cleaner/plugins/base_plugin.py @@ -0,0 +1,15 @@ +from loguru import logger + + +class BasePlugin: + def __init__(self): + err_msg = self.check_dep() + if err_msg: + logger.error(err_msg) + exit(-1) + + def __call__(self, rgb_np_img, files, form): + ... + + def check_dep(self): + ... diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..1aaa6ca9b98ebb2a7c9b4d53ae2a1f7946975ff3 --- /dev/null +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -0,0 +1,71 @@ +import cv2 +from loguru import logger + +from lama_cleaner.helper import download_model +from lama_cleaner.plugins.base_plugin import BasePlugin + + +class GFPGANPlugin(BasePlugin): + name = "GFPGAN" + + def __init__(self, device, upscaler=None): + super().__init__() + from .gfpganer import MyGFPGANer + + url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" + model_md5 = "94d735072630ab734561130a47bc44f8" + model_path = download_model(url, model_md5) + logger.info(f"GFPGAN model path: {model_path}") + + import facexlib + + if hasattr(facexlib.detection.retinaface, "device"): + facexlib.detection.retinaface.device = device + + # Use GFPGAN for face enhancement + self.face_enhancer = MyGFPGANer( + model_path=model_path, + upscale=1, + arch="clean", + channel_multiplier=2, + device=device, + bg_upsampler=upscaler.model if upscaler is not None else None, + ) + self.face_enhancer.face_helper.face_det.mean_tensor.to(device) + self.face_enhancer.face_helper.face_det = ( + self.face_enhancer.face_helper.face_det.to(device) + ) + + def __call__(self, rgb_np_img, files, form): + weight = 0.5 + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + logger.info(f"GFPGAN input shape: {bgr_np_img.shape}") + _, _, bgr_output = self.face_enhancer.enhance( + bgr_np_img, + has_aligned=False, + only_center_face=False, + paste_back=True, + weight=weight, + ) + logger.info(f"GFPGAN output shape: {bgr_output.shape}") + + # try: + # if scale != 2: + # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 + # h, w = img.shape[0:2] + # output = cv2.resize( + # output, + # (int(w * scale / 2), int(h * scale / 2)), + # interpolation=interpolation, + # ) + # except Exception as error: + # print("wrong scale input.", error) + return bgr_output + + def check_dep(self): + try: + import gfpgan + except ImportError: + return ( + "gfpgan is not installed, please install it first. pip install gfpgan" + ) diff --git a/lama_cleaner/plugins/gfpganer.py b/lama_cleaner/plugins/gfpganer.py new file mode 100644 index 0000000000000000000000000000000000000000..04d5e6b51316a88b7d0d7c784981a4295908ba34 --- /dev/null +++ b/lama_cleaner/plugins/gfpganer.py @@ -0,0 +1,84 @@ +import os + +import torch +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from gfpgan import GFPGANv1Clean, GFPGANer +from torch.hub import get_dir + + +class MyGFPGANer(GFPGANer): + """Helper for restoration with GFPGAN. + + It will detect and crop faces, and then resize the faces to 512x512. + GFPGAN is used to restored the resized faces. + The background is upsampled with the bg_upsampler. + Finally, the faces will be pasted back to the upsample background image. + + Args: + model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). + upscale (float): The upscale of the final output. Default: 2. + arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + bg_upsampler (nn.Module): The upsampler for the background. Default: None. + """ + + def __init__( + self, + model_path, + upscale=2, + arch="clean", + channel_multiplier=2, + bg_upsampler=None, + device=None, + ): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + # initialize the GFP-GAN + if arch == "clean": + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True, + ) + elif arch == "RestoreFormer": + from gfpgan.archs.restoreformer_arch import RestoreFormer + + self.gfpgan = RestoreFormer() + + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + use_parse=True, + device=self.device, + model_rootpath=model_dir, + ) + + loadnet = torch.load(model_path) + if "params_ema" in loadnet: + keyname = "params_ema" + else: + keyname = "params" + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) diff --git a/lama_cleaner/plugins/gif.py b/lama_cleaner/plugins/gif.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ff7ce65a48e185f266d063ec573cd0dd82c54e --- /dev/null +++ b/lama_cleaner/plugins/gif.py @@ -0,0 +1,149 @@ +import io +import math + +from PIL import Image, ImageDraw + +from lama_cleaner.helper import load_img +from lama_cleaner.plugins.base_plugin import BasePlugin + + +def keep_ratio_resize(img, size, resample=Image.BILINEAR): + if img.width > img.height: + w = size + h = int(img.height * size / img.width) + else: + h = size + w = int(img.width * size / img.height) + return img.resize((w, h), resample) + + +def cubic_bezier(p1, p2, duration: int, frames: int): + """ + + Args: + p1: + p2: + duration: Total duration of the curve + frames: + + Returns: + + """ + x0, y0 = (0, 0) + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = (1, 1) + + def cal_y(t): + return ( + math.pow(1 - t, 3) * y0 + + 3 * math.pow(1 - t, 2) * t * y1 + + 3 * (1 - t) * math.pow(t, 2) * y2 + + math.pow(t, 3) * y3 + ) + + def cal_x(t): + return ( + math.pow(1 - t, 3) * x0 + + 3 * math.pow(1 - t, 2) * t * x1 + + 3 * (1 - t) * math.pow(t, 2) * x2 + + math.pow(t, 3) * x3 + ) + + res = [] + for t in range(0, 1 * frames, duration): + t = t / frames + res.append((cal_x(t), cal_y(t))) + + res.append((1, 0)) + return res + + +def make_compare_gif( + clean_img: Image.Image, + src_img: Image.Image, + max_side_length: int = 600, + splitter_width: int = 5, + splitter_color=(255, 203, 0, int(255 * 0.73)), +): + if clean_img.size != src_img.size: + clean_img = clean_img.resize(src_img.size, Image.BILINEAR) + + duration_per_frame = 20 + num_frames = 50 + # erase-in-out + cubic_bezier_points = cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames) + cubic_bezier_points.reverse() + + max_side_length = min(max_side_length, max(clean_img.size)) + + src_img = keep_ratio_resize(src_img, max_side_length) + clean_img = keep_ratio_resize(clean_img, max_side_length) + width, height = src_img.size + + # Generate images to make Gif from right to left + images = [] + + for i in range(num_frames): + new_frame = Image.new("RGB", (width, height)) + new_frame.paste(clean_img, (0, 0)) + + left = int(cubic_bezier_points[i][0] * width) + cropped_src_img = src_img.crop((left, 0, width, height)) + new_frame.paste(cropped_src_img, (left, 0, width, height)) + if i != num_frames - 1: + # draw a yellow splitter on the edge of the cropped image + draw = ImageDraw.Draw(new_frame) + draw.line( + [(left, 0), (left, height)], width=splitter_width, fill=splitter_color + ) + images.append(new_frame) + + for i in range(30): + images.append(src_img) + + cubic_bezier_points.reverse() + # Generate images to make Gif from left to right + for i in range(num_frames): + new_frame = Image.new("RGB", (width, height)) + new_frame.paste(src_img, (0, 0)) + + right = int(cubic_bezier_points[i][0] * width) + cropped_src_img = clean_img.crop((0, 0, right, height)) + new_frame.paste(cropped_src_img, (0, 0, right, height)) + if i != num_frames - 1: + # draw a yellow splitter on the edge of the cropped image + draw = ImageDraw.Draw(new_frame) + draw.line( + [(right, 0), (right, height)], width=splitter_width, fill=splitter_color + ) + images.append(new_frame) + + for _ in range(30): + images.append(clean_img) + + img_byte_arr = io.BytesIO() + clean_img.save( + img_byte_arr, + format="GIF", + save_all=True, + include_color_table=True, + append_images=images, + optimize=False, + duration=duration_per_frame, + loop=0, + ) + return img_byte_arr.getvalue() + + +class MakeGIF(BasePlugin): + name = "MakeGIF" + + def __call__(self, rgb_np_img, files, form): + origin_image = rgb_np_img + clean_image_bytes = files["clean_img"].read() + clean_image, _ = load_img(clean_image_bytes) + gif_bytes = make_compare_gif( + Image.fromarray(origin_image), Image.fromarray(clean_image) + ) + return gif_bytes diff --git a/lama_cleaner/plugins/interactive_seg.py b/lama_cleaner/plugins/interactive_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..269d6afef0b1f8959234071d77d41bcc161f436e --- /dev/null +++ b/lama_cleaner/plugins/interactive_seg.py @@ -0,0 +1,75 @@ +import json + +import cv2 +import numpy as np +from loguru import logger + +from lama_cleaner.helper import download_model +from lama_cleaner.plugins.base_plugin import BasePlugin +from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry + +# 从小到大 +SEGMENT_ANYTHING_MODELS = { + "vit_b": { + "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + "md5": "01ec64d29a2fca3f0661936605ae66f8", + }, + "vit_l": { + "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "md5": "0b3195507c641ddb6910d2bb5adee89c", + }, + "vit_h": { + "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "md5": "4b8939a88964f0f4ff5f5b2642c598a6", + }, +} + + +class InteractiveSeg(BasePlugin): + name = "InteractiveSeg" + + def __init__(self, model_name, device): + super().__init__() + model_path = download_model( + SEGMENT_ANYTHING_MODELS[model_name]["url"], + SEGMENT_ANYTHING_MODELS[model_name]["md5"], + ) + logger.info(f"SegmentAnything model path: {model_path}") + self.predictor = SamPredictor( + sam_model_registry[model_name](checkpoint=model_path).to(device) + ) + self.prev_img_md5 = None + + def __call__(self, rgb_np_img, files, form): + clicks = json.loads(form["clicks"]) + return self.forward(rgb_np_img, clicks, form["img_md5"]) + + def forward(self, rgb_np_img, clicks, img_md5): + input_point = [] + input_label = [] + for click in clicks: + x = click[0] + y = click[1] + input_point.append([x, y]) + input_label.append(click[2]) + + if img_md5 and img_md5 != self.prev_img_md5: + self.prev_img_md5 = img_md5 + self.predictor.set_image(rgb_np_img) + + masks, scores, _ = self.predictor.predict( + point_coords=np.array(input_point), + point_labels=np.array(input_label), + multimask_output=False, + ) + mask = masks[0].astype(np.uint8) * 255 + # TODO: how to set kernel size? + kernel_size = 9 + mask = cv2.dilate( + mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1 + ) + # fronted brush color "ffcc00bb" + res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) + res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)] + res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA) + return res_mask diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..6f941bfe78de75ebc7cb60d1c1e01cda2c2048ac --- /dev/null +++ b/lama_cleaner/plugins/realesrgan.py @@ -0,0 +1,96 @@ +from enum import Enum + +import cv2 +from loguru import logger + +from lama_cleaner.const import RealESRGANModelName +from lama_cleaner.helper import download_model +from lama_cleaner.plugins.base_plugin import BasePlugin + + +class RealESRGANUpscaler(BasePlugin): + name = "RealESRGAN" + + def __init__(self, name, device, no_half=False): + super().__init__() + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + from realesrgan.archs.srvgg_arch import SRVGGNetCompact + + REAL_ESRGAN_MODELS = { + RealESRGANModelName.realesr_general_x4v3: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + "scale": 4, + "model": lambda: SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ), + "model_md5": "91a7644643c884ee00737db24e478156", + }, + RealESRGANModelName.RealESRGAN_x4plus: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "scale": 4, + "model": lambda: RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ), + "model_md5": "99ec365d4afad750833258a1a24f44ca", + }, + RealESRGANModelName.RealESRGAN_x4plus_anime_6B: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + "scale": 4, + "model": lambda: RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=6, + num_grow_ch=32, + scale=4, + ), + "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba", + }, + } + if name not in REAL_ESRGAN_MODELS: + raise ValueError(f"Unknown RealESRGAN model name: {name}") + model_info = REAL_ESRGAN_MODELS[name] + + model_path = download_model(model_info["url"], model_info["model_md5"]) + logger.info(f"RealESRGAN model path: {model_path}") + + self.model = RealESRGANer( + scale=model_info["scale"], + model_path=model_path, + model=model_info["model"](), + half=True if "cuda" in str(device) and not no_half else False, + tile=512, + tile_pad=10, + pre_pad=10, + device=device, + ) + + def __call__(self, rgb_np_img, files, form): + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + scale = float(form["upscale"]) + logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}") + result = self.forward(bgr_np_img, scale) + logger.info(f"RealESRGAN output shape: {result.shape}") + return result + + def forward(self, bgr_np_img, scale: float): + # 输出是 BGR + upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] + return upsampled + + def check_dep(self): + try: + import realesrgan + except ImportError: + return "RealESRGAN is not installed, please install it first. pip install realesrgan" diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py new file mode 100644 index 0000000000000000000000000000000000000000..33438f14b397c183e425ed6ac1159eaee13943ac --- /dev/null +++ b/lama_cleaner/plugins/remove_bg.py @@ -0,0 +1,39 @@ +import os +import cv2 +import numpy as np +from torch.hub import get_dir + +from lama_cleaner.plugins.base_plugin import BasePlugin + + +class RemoveBG(BasePlugin): + name = "RemoveBG" + + def __init__(self): + super().__init__() + from rembg import new_session + + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + os.environ["U2NET_HOME"] = model_dir + + self.session = new_session(model_name="u2net") + + def __call__(self, rgb_np_img, files, form): + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + return self.forward(bgr_np_img) + + def forward(self, bgr_np_img) -> np.ndarray: + from rembg import remove + + # return BGRA image + output = remove(bgr_np_img, session=self.session) + return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) + + def check_dep(self): + try: + import rembg + except ImportError: + return ( + "RemoveBG is not installed, please install it first. pip install rembg" + ) diff --git a/lama_cleaner/plugins/restoreformer.py b/lama_cleaner/plugins/restoreformer.py new file mode 100644 index 0000000000000000000000000000000000000000..37fb4cea3d739744f988b4cc70b15e78895d6eeb --- /dev/null +++ b/lama_cleaner/plugins/restoreformer.py @@ -0,0 +1,54 @@ +import cv2 +from loguru import logger + +from lama_cleaner.helper import download_model +from lama_cleaner.plugins.base_plugin import BasePlugin + + +class RestoreFormerPlugin(BasePlugin): + name = "RestoreFormer" + + def __init__(self, device, upscaler=None): + super().__init__() + from .gfpganer import MyGFPGANer + + url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth" + model_md5 = "eaeeff6c4a1caa1673977cb374e6f699" + model_path = download_model(url, model_md5) + logger.info(f"RestoreFormer model path: {model_path}") + + import facexlib + + if hasattr(facexlib.detection.retinaface, "device"): + facexlib.detection.retinaface.device = device + + self.face_enhancer = MyGFPGANer( + model_path=model_path, + upscale=1, + arch="RestoreFormer", + channel_multiplier=2, + device=device, + bg_upsampler=upscaler.model if upscaler is not None else None, + ) + + def __call__(self, rgb_np_img, files, form): + weight = 0.5 + bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) + logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}") + _, _, bgr_output = self.face_enhancer.enhance( + bgr_np_img, + has_aligned=False, + only_center_face=False, + paste_back=True, + weight=weight, + ) + logger.info(f"RestoreFormer output shape: {bgr_output.shape}") + return bgr_output + + def check_dep(self): + try: + import gfpgan + except ImportError: + return ( + "gfpgan is not installed, please install it first. pip install gfpgan" + ) diff --git a/lama_cleaner/plugins/segment_anything/__init__.py b/lama_cleaner/plugins/segment_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d63be7f19ace303f4d438081cb4614fbfb532e --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor diff --git a/lama_cleaner/plugins/segment_anything/build_sam.py b/lama_cleaner/plugins/segment_anything/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..854077a2bcffae09dc91544c45b87b70204c193a --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam, + "vit_h": build_sam, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/lama_cleaner/plugins/segment_anything/modeling/__init__.py b/lama_cleaner/plugins/segment_anything/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f875314844385d2a5cb8279bfa63cbfc555798d --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/lama_cleaner/plugins/segment_anything/modeling/common.py b/lama_cleaner/plugins/segment_anything/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5c92073d1fd6a44d9a7f3abb9ab610d3ccbcac12 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/lama_cleaner/plugins/segment_anything/modeling/image_encoder.py b/lama_cleaner/plugins/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..53531801d33e37aeb1789a7377afff21250c188c --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/lama_cleaner/plugins/segment_anything/modeling/mask_decoder.py b/lama_cleaner/plugins/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bb673fb62e8fccc4cc01e777a764b8df6e327cfb --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for outptu + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/lama_cleaner/plugins/segment_anything/modeling/prompt_encoder.py b/lama_cleaner/plugins/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4f73520ad1318da91f271a623c8497c8b9a31475 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/lama_cleaner/plugins/segment_anything/modeling/sam.py b/lama_cleaner/plugins/segment_anything/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..94411cfd491aa8710bb18b30d0b2993398b77679 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/lama_cleaner/plugins/segment_anything/modeling/transformer.py b/lama_cleaner/plugins/segment_anything/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bea551d9e6a1f60ac2d3facce1ef89e2033508 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attenion layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/lama_cleaner/plugins/segment_anything/predictor.py b/lama_cleaner/plugins/segment_anything/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ccd3d3bd4a15b1eb7eaff449446f02ec33ed21 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/predictor.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + from .utils.transforms import ResizeLongestSide + + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ + None, :, :, : + ] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + labels_torch = torch.as_tensor( + point_labels, dtype=torch.int, device=self.device + ) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor( + mask_input, dtype=torch.float, device=self.device + ) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks( + low_res_masks, self.input_size, self.original_size + ) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self.features is not None + ), "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/lama_cleaner/plugins/segment_anything/utils/__init__.py b/lama_cleaner/plugins/segment_anything/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/lama_cleaner/plugins/segment_anything/utils/transforms.py b/lama_cleaner/plugins/segment_anything/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ed9092a17d89fb7ce5007700646d3a4868316342 --- /dev/null +++ b/lama_cleaner/plugins/segment_anything/utils/transforms.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords( + self, coords: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes( + self, boxes: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape( + oldh: int, oldw: int, long_side_length: int + ) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/lama_cleaner/runtime.py b/lama_cleaner/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..d040a4117e13902b97046f0aaf6b363ea2ce071a --- /dev/null +++ b/lama_cleaner/runtime.py @@ -0,0 +1,50 @@ +# https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py +import platform +import sys +import packaging.version +from rich import print +from typing import Dict, Any + +_PY_VERSION: str = sys.version.split()[0].rstrip("+") + +if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"): + import importlib_metadata # type: ignore +else: + import importlib.metadata as importlib_metadata # type: ignore + +_package_versions = {} + +_CANDIDATES = [ + "torch", + "torchvision", + "Pillow", + "diffusers", + "transformers", + "opencv-python", + "xformers", + "accelerate", + "lama-cleaner", + "rembg", + "realesrgan", + "gfpgan", +] +# Check once at runtime +for name in _CANDIDATES: + _package_versions[name] = "N/A" + try: + _package_versions[name] = importlib_metadata.version(name) + except importlib_metadata.PackageNotFoundError: + pass + + +def dump_environment_info() -> Dict[str, str]: + """Dump information about the machine to help debugging issues. """ + + # Generic machine info + info: Dict[str, Any] = { + "Platform": platform.platform(), + "Python version": platform.python_version(), + } + info.update(_package_versions) + print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") + return info diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed2817e69f38bbf19f5e3aa21c1f3b01554cd9f --- /dev/null +++ b/lama_cleaner/schema.py @@ -0,0 +1,101 @@ +from typing import Optional +from enum import Enum + +from PIL.Image import Image +from pydantic import BaseModel + + +class HDStrategy(str, Enum): + # Use original image size + ORIGINAL = "Original" + # Resize the longer side of the image to a specific size(hd_strategy_resize_limit), + # then do inpainting on the resized image. Finally, resize the inpainting result to the original size. + # The area outside the mask will not lose quality. + RESIZE = "Resize" + # Crop masking area(with a margin controlled by hd_strategy_crop_margin) from the original image to do inpainting + CROP = "Crop" + + +class LDMSampler(str, Enum): + ddim = "ddim" + plms = "plms" + + +class SDSampler(str, Enum): + ddim = "ddim" + pndm = "pndm" + k_lms = "k_lms" + k_euler = "k_euler" + k_euler_a = "k_euler_a" + dpm_plus_plus = "dpm++" + uni_pc = "uni_pc" + + +class Config(BaseModel): + class Config: + arbitrary_types_allowed = True + + # Configs for ldm model + ldm_steps: int + ldm_sampler: str = LDMSampler.plms + + # Configs for zits model + zits_wireframe: bool = True + + # Configs for High Resolution Strategy(different way to preprocess image) + hd_strategy: str # See HDStrategy Enum + hd_strategy_crop_margin: int + # If the longer side of the image is larger than this value, use crop strategy + hd_strategy_crop_trigger_size: int + hd_strategy_resize_limit: int + + # Configs for Stable Diffusion 1.5 + prompt: str = "" + negative_prompt: str = "" + # Crop image to this size before doing sd inpainting + # The value is always on the original image scale + use_croper: bool = False + croper_x: int = None + croper_y: int = None + croper_height: int = None + croper_width: int = None + + # Resize the image before doing sd inpainting, the area outside the mask will not lose quality. + # Used by sd models and paint_by_example model + sd_scale: float = 1.0 + # Blur the edge of mask area. The higher the number the smoother blend with the original image + sd_mask_blur: int = 0 + # Ignore this value, it's useless for inpainting + sd_strength: float = 0.75 + # The number of denoising steps. More denoising steps usually lead to a + # higher quality image at the expense of slower inference. + sd_steps: int = 50 + # Higher guidance scale encourages to generate images that are closely linked + # to the text prompt, usually at the expense of lower image quality. + sd_guidance_scale: float = 7.5 + sd_sampler: str = SDSampler.uni_pc + # -1 mean random seed + sd_seed: int = 42 + sd_match_histograms: bool = False + + # Configs for opencv inpainting + # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 + cv2_flag: str = "INPAINT_NS" + cv2_radius: int = 4 + + # Paint by Example + paint_by_example_steps: int = 50 + paint_by_example_guidance_scale: float = 7.5 + paint_by_example_mask_blur: int = 0 + paint_by_example_seed: int = 42 + paint_by_example_match_histograms: bool = False + paint_by_example_example_image: Optional[Image] = None + + # InstructPix2Pix + p2p_steps: int = 50 + p2p_image_guidance_scale: float = 1.5 + p2p_guidance_scale: float = 7.5 + + # ControlNet + controlnet_conditioning_scale: float = 0.4 + controlnet_method: str = "control_v11p_sd15_canny" diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py new file mode 100644 index 0000000000000000000000000000000000000000..d10b31db79bf526081990f3335128388923f4497 --- /dev/null +++ b/lama_cleaner/server.py @@ -0,0 +1,622 @@ +#!/usr/bin/env python3 +import os +import hashlib + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +import imghdr +import io +import logging +import multiprocessing +import random +import time +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PIL import Image +from loguru import logger + +from lama_cleaner.const import SD15_MODELS +from lama_cleaner.file_manager import FileManager +from lama_cleaner.model.utils import torch_gc +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.plugins import ( + InteractiveSeg, + RemoveBG, + RealESRGANUpscaler, + MakeGIF, + GFPGANPlugin, + RestoreFormerPlugin, + AnimeSeg, +) +from lama_cleaner.schema import Config + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +from flask import ( + Flask, + request, + send_file, + cli, + make_response, + send_from_directory, + jsonify, +) +from flask_socketio import SocketIO + +# Disable ability for Flask to display warning about using a development server in a production environment. +# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 +cli.show_server_banner = lambda *_: None +from flask_cors import CORS + +from lama_cleaner.helper import ( + load_img, + numpy_to_bytes, + resize_max_size, + pil_to_bytes, +) + +NUM_THREADS = str(multiprocessing.cpu_count()) + +# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + +BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") + + +class NoFlaskwebgui(logging.Filter): + def filter(self, record): + msg = record.getMessage() + if "Running on http:" in msg: + print(msg[msg.index("Running on http:") :]) + + return ( + "flaskwebgui-keep-server-alive" not in msg + and "socket.io" not in msg + and "This is a development server." not in msg + ) + + +logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) + +app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) +app.config["JSON_AS_ASCII"] = False +CORS(app, expose_headers=["Content-Disposition"]) + +sio_logger = logging.getLogger("sio-logger") +sio_logger.setLevel(logging.ERROR) +socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading") + +model: ModelManager = None +thumb: FileManager = None +output_dir: str = None +device = None +input_image_path: str = None +is_disable_model_switch: bool = False +is_controlnet: bool = False +controlnet_method: str = "control_v11p_sd15_canny" +is_enable_file_manager: bool = False +is_enable_auto_saving: bool = False +is_desktop: bool = False +image_quality: int = 95 +plugins = {} + + +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + +def diffuser_callback(i, t, latents): + socketio.emit("diffusion_progress", {"step": i}) + + +@app.route("/save_image", methods=["POST"]) +def save_image(): + if output_dir is None: + return "--output-dir is None", 500 + + input = request.files + filename = request.form["filename"] + origin_image_bytes = input["image"].read() # RGB + ext = get_image_ext(origin_image_bytes) + image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True) + save_path = os.path.join(output_dir, filename) + + if alpha_channel is not None: + if alpha_channel.shape[:2] != image.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(image.shape[1], image.shape[0]) + ) + image = np.concatenate((image, alpha_channel[:, :, np.newaxis]), axis=-1) + + pil_image = Image.fromarray(image) + + img_bytes = pil_to_bytes( + pil_image, + ext, + quality=image_quality, + exif_infos=exif_infos, + ) + with open(save_path, "wb") as fw: + fw.write(img_bytes) + + return "ok", 200 + + +@app.route("/medias/") +def medias(tab): + if tab == "image": + response = make_response(jsonify(thumb.media_names), 200) + else: + response = make_response(jsonify(thumb.output_media_names), 200) + # response.last_modified = thumb.modified_time[tab] + # response.cache_control.no_cache = True + # response.cache_control.max_age = 0 + # response.make_conditional(request) + return response + + +@app.route("/media//") +def media_file(tab, filename): + if tab == "image": + return send_from_directory(thumb.root_directory, filename) + return send_from_directory(thumb.output_dir, filename) + + +@app.route("/media_thumbnail//") +def media_thumbnail_file(tab, filename): + args = request.args + width = args.get("width") + height = args.get("height") + if width is None and height is None: + width = 256 + if width: + width = int(float(width)) + if height: + height = int(float(height)) + + directory = thumb.root_directory + if tab == "output": + directory = thumb.output_dir + thumb_filename, (width, height) = thumb.get_thumbnail( + directory, filename, width, height + ) + thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}" + + response = make_response(send_file(thumb_filepath)) + response.headers["X-Width"] = str(width) + response.headers["X-Height"] = str(height) + return response + + +@app.route("/inpaint", methods=["POST"]) +def process(): + input = request.files + # RGB + origin_image_bytes = input["image"].read() + image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True) + + mask, _ = load_img(input["mask"].read(), gray=True) + mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] + + if image.shape[:2] != mask.shape[:2]: + return ( + f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", + 400, + ) + + original_shape = image.shape + interpolation = cv2.INTER_CUBIC + + form = request.form + size_limit = max(image.shape) + + if "paintByExampleImage" in input: + paint_by_example_example_image, _ = load_img( + input["paintByExampleImage"].read() + ) + paint_by_example_example_image = Image.fromarray(paint_by_example_example_image) + else: + paint_by_example_example_image = None + + config = Config( + ldm_steps=form["ldmSteps"], + ldm_sampler=form["ldmSampler"], + hd_strategy=form["hdStrategy"], + zits_wireframe=form["zitsWireframe"], + hd_strategy_crop_margin=form["hdStrategyCropMargin"], + hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"], + hd_strategy_resize_limit=form["hdStrategyResizeLimit"], + prompt=form["prompt"], + negative_prompt=form["negativePrompt"], + use_croper=form["useCroper"], + croper_x=form["croperX"], + croper_y=form["croperY"], + croper_height=form["croperHeight"], + croper_width=form["croperWidth"], + sd_scale=form["sdScale"], + sd_mask_blur=form["sdMaskBlur"], + sd_strength=form["sdStrength"], + sd_steps=form["sdSteps"], + sd_guidance_scale=form["sdGuidanceScale"], + sd_sampler=form["sdSampler"], + sd_seed=form["sdSeed"], + sd_match_histograms=form["sdMatchHistograms"], + cv2_flag=form["cv2Flag"], + cv2_radius=form["cv2Radius"], + paint_by_example_steps=form["paintByExampleSteps"], + paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"], + paint_by_example_mask_blur=form["paintByExampleMaskBlur"], + paint_by_example_seed=form["paintByExampleSeed"], + paint_by_example_match_histograms=form["paintByExampleMatchHistograms"], + paint_by_example_example_image=paint_by_example_example_image, + p2p_steps=form["p2pSteps"], + p2p_image_guidance_scale=form["p2pImageGuidanceScale"], + p2p_guidance_scale=form["p2pGuidanceScale"], + controlnet_conditioning_scale=form["controlnet_conditioning_scale"], + controlnet_method=form["controlnet_method"], + ) + + if config.sd_seed == -1: + config.sd_seed = random.randint(1, 999999999) + if config.paint_by_example_seed == -1: + config.paint_by_example_seed = random.randint(1, 999999999) + + logger.info(f"Origin image shape: {original_shape}") + image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) + + mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) + + start = time.time() + try: + res_np_img = model(image, mask, config) + except RuntimeError as e: + if "CUDA out of memory. " in str(e): + # NOTE: the string may change? + return "CUDA out of memory", 500 + else: + logger.exception(e) + return f"{str(e)}", 500 + finally: + logger.info(f"process time: {(time.time() - start) * 1000}ms") + torch_gc() + + res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) + if alpha_channel is not None: + if alpha_channel.shape[:2] != res_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) + ) + res_np_img = np.concatenate( + (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + + ext = get_image_ext(origin_image_bytes) + + bytes_io = io.BytesIO( + pil_to_bytes( + Image.fromarray(res_np_img), + ext, + quality=image_quality, + exif_infos=exif_infos, + ) + ) + + response = make_response( + send_file( + # io.BytesIO(numpy_to_bytes(res_np_img, ext)), + bytes_io, + mimetype=f"image/{ext}", + ) + ) + response.headers["X-Seed"] = str(config.sd_seed) + + socketio.emit("diffusion_finish") + return response + + +@app.route("/run_plugin", methods=["POST"]) +def run_plugin(): + form = request.form + files = request.files + name = form["name"] + if name not in plugins: + return "Plugin not found", 500 + + origin_image_bytes = files["image"].read() # RGB + rgb_np_img, alpha_channel, exif_infos = load_img( + origin_image_bytes, return_exif=True + ) + + start = time.time() + try: + form = dict(form) + if name == InteractiveSeg.name: + img_md5 = hashlib.md5(origin_image_bytes).hexdigest() + form["img_md5"] = img_md5 + bgr_res = plugins[name](rgb_np_img, files, form) + except RuntimeError as e: + torch.cuda.empty_cache() + if "CUDA out of memory. " in str(e): + # NOTE: the string may change? + return "CUDA out of memory", 500 + else: + logger.exception(e) + return "Internal Server Error", 500 + + logger.info(f"{name} process time: {(time.time() - start) * 1000}ms") + torch_gc() + + if name == MakeGIF.name: + return send_file( + io.BytesIO(bgr_res), + mimetype="image/gif", + as_attachment=True, + download_name=form["filename"], + ) + if name == InteractiveSeg.name: + return make_response( + send_file( + io.BytesIO(numpy_to_bytes(bgr_res, "png")), + mimetype="image/png", + ) + ) + + if name in [RemoveBG.name, AnimeSeg.name]: + rgb_res = bgr_res + ext = "png" + else: + rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB) + ext = get_image_ext(origin_image_bytes) + if alpha_channel is not None: + if alpha_channel.shape[:2] != rgb_res.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0]) + ) + rgb_res = np.concatenate( + (rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + + response = make_response( + send_file( + io.BytesIO( + pil_to_bytes( + Image.fromarray(rgb_res), + ext, + quality=image_quality, + exif_infos=exif_infos, + ) + ), + mimetype=f"image/{ext}", + ) + ) + return response + + +@app.route("/server_config", methods=["GET"]) +def get_server_config(): + return { + "isControlNet": is_controlnet, + "controlNetMethod": controlnet_method, + "isDisableModelSwitchState": is_disable_model_switch, + "isEnableAutoSaving": is_enable_auto_saving, + "enableFileManager": is_enable_file_manager, + "plugins": list(plugins.keys()), + }, 200 + + +@app.route("/model") +def current_model(): + return model.name, 200 + + +@app.route("/model_downloaded/") +def model_downloaded(name): + return str(model.is_downloaded(name)), 200 + + +@app.route("/is_desktop") +def get_is_desktop(): + return str(is_desktop), 200 + + +@app.route("/model", methods=["POST"]) +def switch_model(): + if is_disable_model_switch: + return "Switch model is disabled", 400 + + new_name = request.form.get("name") + if new_name == model.name: + return "Same model", 200 + + try: + model.switch(new_name) + except NotImplementedError: + return f"{new_name} not implemented", 403 + return f"ok, switch to {new_name}", 200 + + +@app.route("/") +def index(): + return send_file(os.path.join(BUILD_DIR, "index.html")) + + +@app.route("/inputimage") +def set_input_photo(): + if input_image_path: + with open(input_image_path, "rb") as f: + image_in_bytes = f.read() + return send_file( + input_image_path, + as_attachment=True, + download_name=Path(input_image_path).name, + mimetype=f"image/{get_image_ext(image_in_bytes)}", + ) + else: + return "No Input Image" + + +def build_plugins(args): + global plugins + if args.enable_interactive_seg: + logger.info(f"Initialize {InteractiveSeg.name} plugin") + plugins[InteractiveSeg.name] = InteractiveSeg( + args.interactive_seg_model, args.interactive_seg_device + ) + + if args.enable_remove_bg: + logger.info(f"Initialize {RemoveBG.name} plugin") + plugins[RemoveBG.name] = RemoveBG() + + if args.enable_anime_seg: + logger.info(f"Initialize {AnimeSeg.name} plugin") + plugins[AnimeSeg.name] = AnimeSeg() + + if args.enable_realesrgan: + logger.info( + f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}" + ) + plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( + args.realesrgan_model, + args.realesrgan_device, + no_half=args.realesrgan_no_half, + ) + + if args.enable_gfpgan: + logger.info(f"Initialize {GFPGANPlugin.name} plugin") + if args.enable_realesrgan: + logger.info("Use realesrgan as GFPGAN background upscaler") + else: + logger.info( + f"GFPGAN no background upscaler, use --enable-realesrgan to enable it" + ) + plugins[GFPGANPlugin.name] = GFPGANPlugin( + args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None) + ) + + if args.enable_restoreformer: + logger.info(f"Initialize {RestoreFormerPlugin.name} plugin") + plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( + args.restoreformer_device, + upscaler=plugins.get(RealESRGANUpscaler.name, None), + ) + + if args.enable_gif: + logger.info(f"Initialize GIF plugin") + plugins[MakeGIF.name] = MakeGIF() + + +def main(args): + global model + global device + global input_image_path + global is_disable_model_switch + global is_enable_file_manager + global is_desktop + global thumb + global output_dir + global is_enable_auto_saving + global is_controlnet + global controlnet_method + global image_quality + + build_plugins(args) + + image_quality = args.quality + + if args.sd_controlnet and args.model in SD15_MODELS: + is_controlnet = True + controlnet_method = args.sd_controlnet_method + + output_dir = args.output_dir + if output_dir: + is_enable_auto_saving = True + + device = torch.device(args.device) + is_disable_model_switch = args.disable_model_switch + is_desktop = args.gui + if is_disable_model_switch: + logger.info( + f"Start with --disable-model-switch, model switch on frontend is disable" + ) + + if args.input and os.path.isdir(args.input): + logger.info(f"Initialize file manager") + thumb = FileManager(app) + is_enable_file_manager = True + app.config["THUMBNAIL_MEDIA_ROOT"] = args.input + app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( + args.output_dir, "lama_cleaner_thumbnails" + ) + thumb.output_dir = Path(args.output_dir) + # thumb.start() + # try: + # while True: + # time.sleep(1) + # finally: + # thumb.image_dir_observer.stop() + # thumb.image_dir_observer.join() + # thumb.output_dir_observer.stop() + # thumb.output_dir_observer.join() + + else: + input_image_path = args.input + + model = ModelManager( + name=args.model, + sd_controlnet=args.sd_controlnet, + sd_controlnet_method=args.sd_controlnet_method, + device=device, + no_half=args.no_half, + hf_access_token=args.hf_access_token, + disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw, + sd_cpu_textencoder=args.sd_cpu_textencoder, + sd_run_local=args.sd_run_local, + sd_local_model_path=args.sd_local_model_path, + local_files_only=args.local_files_only, + cpu_offload=args.cpu_offload, + enable_xformers=args.sd_enable_xformers or args.enable_xformers, + callback=diffuser_callback, + ) + + if args.gui: + app_width, app_height = args.gui_size + from flaskwebgui import FlaskUI + + ui = FlaskUI( + app, + socketio=socketio, + width=app_width, + height=app_height, + host=args.host, + port=args.port, + close_server_on_exit=not args.no_gui_auto_close, + ) + ui.run() + else: + socketio.run( + app, + host=args.host, + port=args.port, + debug=args.debug, + allow_unsafe_werkzeug=True, + ) diff --git a/lama_cleaner/tests/__init__.py b/lama_cleaner/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6df9b1179ad9582b934a6769f42f2fefe2e72ef --- /dev/null +++ b/lama_cleaner/tests/test_controlnet.py @@ -0,0 +1,195 @@ +import os + +from lama_cleaner.const import SD_CONTROLNET_CHOICES + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import HDStrategy, SDSampler +from lama_cleaner.tests.test_model import get_config, assert_equal + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) +device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device(device) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +@pytest.mark.parametrize("cpu_textencoder", [True]) +@pytest.mark.parametrize("disable_nsfw", [True]) +@pytest.mark.parametrize("sd_controlnet_method", SD_CONTROLNET_CHOICES) +def test_runway_sd_1_5( + sd_device, strategy, sampler, cpu_textencoder, disable_nsfw, sd_controlnet_method +): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 30 + model = ModelManager( + name="sd1.5", + sd_controlnet=True, + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + sd_controlnet_method=sd_controlnet_method, + ) + + controlnet_conditioning_scale = { + "control_v11p_sd15_canny": 0.4, + "control_v11p_sd15_openpose": 0.4, + "control_v11p_sd15_inpaint": 1.0, + "control_v11f1p_sd15_depth": 1.0, + }[sd_controlnet_method] + + cfg = get_config( + strategy, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + controlnet_conditioning_scale=controlnet_conditioning_scale, + controlnet_method=sd_controlnet_method, + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_disable_nsfw" + + assert_equal( + model, + cfg, + f"sd_controlnet_{sd_controlnet_method}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.2, + fy=1.2, + ) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_local_file_path(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 30 + model = ModelManager( + name="sd1.5", + sd_controlnet=True, + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt", + sd_controlnet_method="control_v11p_sd15_canny", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + controlnet_method="control_v11p_sd15_canny", + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_controlnet_canny_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_local_file_path_controlnet_native_inpainting(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 30 + model = ModelManager( + name="sd1.5", + sd_controlnet=True, + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_local_model_path="/Users/cwq/data/models/v1-5-pruned-emaonly.safetensors", + sd_controlnet_method="control_v11p_sd15_inpaint", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + controlnet_conditioning_scale=1.0, + sd_strength=1.0, + controlnet_method="control_v11p_sd15_inpaint", + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_controlnet_local_native_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_controlnet_switch(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 30 + model = ModelManager( + name="sd1.5", + sd_controlnet=True, + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_controlnet_method="control_v11p_sd15_canny", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + controlnet_method="control_v11p_sd15_inpaint", + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_controlnet_switch_to_inpaint_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/tests/test_instruct_pix2pix.py b/lama_cleaner/tests/test_instruct_pix2pix.py new file mode 100644 index 0000000000000000000000000000000000000000..d434a21265866527a765794aa2e035fa993d6c02 --- /dev/null +++ b/lama_cleaner/tests/test_instruct_pix2pix.py @@ -0,0 +1,62 @@ +from pathlib import Path + +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.tests.test_model import get_config, assert_equal +from lama_cleaner.schema import HDStrategy + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) +device = 'cuda' if torch.cuda.is_available() else 'cpu' + + +@pytest.mark.parametrize("disable_nsfw", [True, False]) +@pytest.mark.parametrize("cpu_offload", [False, True]) +def test_instruct_pix2pix(disable_nsfw, cpu_offload): + sd_steps = 50 if device == 'cuda' else 1 + model = ModelManager(name="instruct_pix2pix", + device=torch.device(device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=False, + cpu_offload=cpu_offload) + cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1) + + name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}" + + assert_equal( + model, + cfg, + f"instruct_pix2pix_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3 + ) + + +@pytest.mark.parametrize("disable_nsfw", [False]) +@pytest.mark.parametrize("cpu_offload", [False]) +def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload): + sd_steps = 50 if device == 'cuda' else 1 + model = ModelManager(name="instruct_pix2pix", + device=torch.device(device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=False, + cpu_offload=cpu_offload) + cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps) + + name = f"snow" + + assert_equal( + model, + cfg, + f"instruct_pix2pix_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/tests/test_interactive_seg.py b/lama_cleaner/tests/test_interactive_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..12a23bb6d70a2a3db98614816ef754817e479030 --- /dev/null +++ b/lama_cleaner/tests/test_interactive_seg.py @@ -0,0 +1,43 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from lama_cleaner.plugins import InteractiveSeg, Click + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) +img_p = current_dir / "overture-creations-5sI6fQgYIuo.png" + + +def test_interactive_seg(): + interactive_seg_model = InteractiveSeg() + img = cv2.imread(str(img_p)) + pred = interactive_seg_model.forward( + img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)] + ) + cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred) + + +def test_interactive_seg_with_negative_click(): + interactive_seg_model = InteractiveSeg() + img = cv2.imread(str(img_p)) + pred = interactive_seg_model.forward( + img, + clicks=[ + Click(coords=(256, 256), indx=0, is_positive=True), + Click(coords=(384, 256), indx=1, is_positive=False), + ], + ) + cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred) + + +def test_interactive_seg_with_prev_mask(): + interactive_seg_model = InteractiveSeg() + img = cv2.imread(str(img_p)) + mask = np.zeros_like(img)[:, :, 0] + pred = interactive_seg_model.forward( + img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask + ) + cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred) diff --git a/lama_cleaner/tests/test_load_img.py b/lama_cleaner/tests/test_load_img.py new file mode 100644 index 0000000000000000000000000000000000000000..033564913be24b1b224352e7e2eece550d0bbbf5 --- /dev/null +++ b/lama_cleaner/tests/test_load_img.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from lama_cleaner.helper import load_img + +current_dir = Path(__file__).parent.absolute().resolve() +png_img_p = current_dir / "image.png" +jpg_img_p = current_dir / "bunny.jpeg" + + +def test_load_png_image(): + with open(png_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (256, 256, 3) + assert alpha_channel.shape == (256, 256) + + +def test_load_jpg_image(): + with open(jpg_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (394, 448, 3) + assert alpha_channel is None diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f7957749d7032376be795cdacda327c0a9546b --- /dev/null +++ b/lama_cleaner/tests/test_model.py @@ -0,0 +1,194 @@ +from pathlib import Path + +import cv2 +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) +device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device(device) + + +def get_data( + fx: float = 1, + fy: float = 1.0, + img_p=current_dir / "image.png", + mask_p=current_dir / "mask.png", +): + img = cv2.imread(str(img_p)) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) + mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST) + return img, mask + + +def get_config(strategy, **kwargs): + data = dict( + ldm_steps=1, + ldm_sampler=LDMSampler.plms, + hd_strategy=strategy, + hd_strategy_crop_margin=32, + hd_strategy_crop_trigger_size=200, + hd_strategy_resize_limit=200, + ) + data.update(**kwargs) + return Config(**data) + + +def assert_equal( + model, + config, + gt_name, + fx: float = 1, + fy: float = 1, + img_p=current_dir / "image.png", + mask_p=current_dir / "mask.png", +): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + print(f"Input image shape: {img.shape}") + res = model(img, mask, config) + cv2.imwrite( + str(save_dir / gt_name), + res, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + ) + + """ + Note that JPEG is lossy compression, so even if it is the highest quality 100, + when the saved images is reloaded, a difference occurs with the original pixel value. + If you want to save the original images as it is, save it as PNG or BMP. + """ + # gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED) + # assert np.array_equal(res, gt) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +def test_lama(strategy): + model = ModelManager(name="lama", device=device) + assert_equal( + model, + get_config(strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + get_config(strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png", + fx=1.3, + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms]) +def test_ldm(strategy, ldm_sampler): + model = ModelManager(name="ldm", device=device) + cfg = get_config(strategy, ldm_sampler=ldm_sampler) + assert_equal( + model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png" + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png", + fx=fx, + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("zits_wireframe", [False, True]) +def test_zits(strategy, zits_wireframe): + model = ModelManager(name="zits", device=device) + cfg = get_config(strategy, zits_wireframe=zits_wireframe) + # os.environ['ZITS_DEBUG_LINE_PATH'] = str(current_dir / 'zits_debug_line.jpg') + # os.environ['ZITS_DEBUG_EDGE_PATH'] = str(current_dir / 'zits_debug_edge.jpg') + assert_equal( + model, + cfg, + f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png", + fx=fx, + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("no_half", [True, False]) +def test_mat(strategy, no_half): + model = ModelManager(name="mat", device=device, no_half=no_half) + cfg = get_config(strategy) + + for _ in range(10): + assert_equal( + model, + cfg, + f"mat_{strategy.capitalize()}_result.png", + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_fcf(strategy): + model = ModelManager(name="fcf", device=device) + cfg = get_config(strategy) + + assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=2, fy=2) + + assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=3.8, fy=2) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("cv2_flag", ["INPAINT_NS", "INPAINT_TELEA"]) +@pytest.mark.parametrize("cv2_radius", [3, 15]) +def test_cv2(strategy, cv2_flag, cv2_radius): + model = ModelManager( + name="cv2", + device=torch.device(device), + ) + cfg = get_config(strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius) + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +def test_manga(strategy): + model = ModelManager( + name="manga", + device=torch.device(device), + ) + cfg = get_config(strategy) + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/tests/test_model_md5.py b/lama_cleaner/tests/test_model_md5.py new file mode 100644 index 0000000000000000000000000000000000000000..67b3e654561ba506a805515d9ebbd643a5a8c871 --- /dev/null +++ b/lama_cleaner/tests/test_model_md5.py @@ -0,0 +1,49 @@ +def test_load_model(): + from lama_cleaner.plugins import InteractiveSeg + from lama_cleaner.model_manager import ModelManager + + interactive_seg_model = InteractiveSeg('vit_l', 'cpu') + + models = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "manga", + ] + for m in models: + ModelManager( + name=m, + device="cpu", + no_half=False, + hf_access_token="", + disable_nsfw=False, + sd_cpu_textencoder=True, + sd_run_local=True, + local_files_only=True, + cpu_offload=True, + enable_xformers=False, + ) + + +# def create_empty_file(tmp_dir, name): +# tmp_model_dir = os.path.join(tmp_dir, "torch", "hub", "checkpoints") +# Path(tmp_model_dir).mkdir(exist_ok=True, parents=True) +# path = os.path.join(tmp_model_dir, name) +# with open(path, "w") as f: +# f.write("1") +# +# +# def test_load_model_error(): +# MODELS = [ +# ("big-lama.pt", "e3aa4aaa15225a33ec84f9f4bc47e500"), +# ("cond_stage_model_encode.pt", "23239fc9081956a3e70de56472b3f296"), +# ("cond_stage_model_decode.pt", "fe419cd15a750d37a4733589d0d3585c"), +# ("diffusion.pt", "b0afda12bf790c03aba2a7431f11d22d"), +# ] +# with tempfile.TemporaryDirectory() as tmp_dir: +# os.environ["XDG_CACHE_HOME"] = tmp_dir +# for name, md5 in MODELS: +# create_empty_file(tmp_dir, name) +# test_load_model() diff --git a/lama_cleaner/tests/test_paint_by_example.py b/lama_cleaner/tests/test_paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..1c4b4689198dca0feef91cedf57bd06db9fef952 --- /dev/null +++ b/lama_cleaner/tests/test_paint_by_example.py @@ -0,0 +1,106 @@ +from pathlib import Path + +import cv2 +import pytest +import torch +from PIL import Image + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import HDStrategy +from lama_cleaner.tests.test_model import get_config, get_data + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) +device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = torch.device(device) + + +def assert_equal( + model, config, gt_name, + fx: float = 1, fy: float = 1, + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + example_p=current_dir / "bunny.jpeg", +): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + + example_image = cv2.imread(str(example_p)) + example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB) + example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) + + print(f"Input image shape: {img.shape}, example_image: {example_image.shape}") + config.paint_by_example_example_image = Image.fromarray(example_image) + res = model(img, mask, config) + cv2.imwrite(str(save_dir / gt_name), res) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) + cfg = get_config(strategy, paint_by_example_steps=30) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3, + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_disable_nsfw(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False) + cfg = get_config(strategy, paint_by_example_steps=30) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_sd_scale(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) + cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_sdscale.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3 + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_cpu_offload(strategy): + model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False) + cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_cpu_offload.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_cpu_offload_cpu_device(strategy): + model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True) + cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_cpu_offload_cpu_device.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fy=0.9, + fx=1.3 + ) diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py new file mode 100644 index 0000000000000000000000000000000000000000..e786ae62eb88e092561a315abe32e2643cefac80 --- /dev/null +++ b/lama_cleaner/tests/test_plugins.py @@ -0,0 +1,103 @@ +import hashlib +import os +import time + +from lama_cleaner.plugins.anime_seg import AnimeSeg + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import cv2 +import pytest +import torch.cuda + +from lama_cleaner.plugins import ( + RemoveBG, + RealESRGANUpscaler, + GFPGANPlugin, + RestoreFormerPlugin, + InteractiveSeg, +) + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) +img_p = current_dir / "bunny.jpeg" +img_bytes = open(img_p, "rb").read() +bgr_img = cv2.imread(str(img_p)) +rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + + +def _save(img, name): + cv2.imwrite(str(save_dir / name), img) + + +def test_remove_bg(): + model = RemoveBG() + res = model.forward(bgr_img) + res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA) + _save(res, "test_remove_bg.png") + + +def test_anime_seg(): + model = AnimeSeg() + img = cv2.imread(str(current_dir / "anime_test.png")) + res = model.forward(img) + assert len(res.shape) == 3 + assert res.shape[-1] == 4 + _save(res, "test_anime_seg.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_upscale(device): + if device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + model = RealESRGANUpscaler("realesr-general-x4v3", device) + res = model.forward(bgr_img, 2) + _save(res, f"test_upscale_x2_{device}.png") + + res = model.forward(bgr_img, 4) + _save(res, f"test_upscale_x4_{device}.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_gfpgan(device): + if device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + model = GFPGANPlugin(device) + res = model(rgb_img, None, None) + _save(res, f"test_gfpgan_{device}.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_restoreformer(device): + if device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + model = RestoreFormerPlugin(device) + res = model(rgb_img, None, None) + _save(res, f"test_restoreformer_{device}.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_segment_anything(device): + if device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + img_md5 = hashlib.md5(img_bytes).hexdigest() + model = InteractiveSeg("vit_l", device) + new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5) + + save_name = f"test_segment_anything_{device}.png" + _save(new_mask, save_name) + + start = time.time() + model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5) + print(f"Time for {save_name}: {time.time() - start:.2f}s") diff --git a/lama_cleaner/tests/test_save_exif.py b/lama_cleaner/tests/test_save_exif.py new file mode 100644 index 0000000000000000000000000000000000000000..86923421e7017e7538aaa08b8794559995ad6e6a --- /dev/null +++ b/lama_cleaner/tests/test_save_exif.py @@ -0,0 +1,43 @@ +import io +from pathlib import Path + +from PIL import Image + +from lama_cleaner.helper import pil_to_bytes, load_img + +current_dir = Path(__file__).parent.absolute().resolve() + + +def print_exif(exif): + for k, v in exif.items(): + print(f"{k}: {v}") + + +def run_test(img_p: Path): + print(img_p) + ext = img_p.suffix.strip(".") + img_bytes = img_p.read_bytes() + np_img, _, exif_infos = load_img(img_bytes, False, True) + print(exif_infos) + print("Original exif_infos") + print_exif(exif_infos["exif"]) + + pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos={}) + + pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos=exif_infos) + res_img = Image.open(io.BytesIO(pil_bytes)) + print(f"Result img info: {res_img.info}") + res_exif = res_img.getexif() + print_exif(res_exif) + assert res_exif == exif_infos["exif"] + assert exif_infos["parameters"] == res_img.info.get("parameters") + + +def test_png(): + run_test(current_dir / "image.png") + run_test(current_dir / "pnginfo_test.png") + + +def test_jpeg(): + jpg_img_p = current_dir / "bunny.jpeg" + run_test(jpg_img_p) diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py new file mode 100644 index 0000000000000000000000000000000000000000..94c25f1b94f912b7adb161c51979a732a5042875 --- /dev/null +++ b/lama_cleaner/tests/test_sd_model.py @@ -0,0 +1,241 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import HDStrategy, SDSampler +from lama_cleaner.tests.test_model import get_config, assert_equal + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) +device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device(device) + + +@pytest.mark.parametrize("sd_device", ["cuda"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +@pytest.mark.parametrize("cpu_textencoder", [True, False]) +@pytest.mark.parametrize("disable_nsfw", [True, False]) +def test_runway_sd_1_5_ddim( + sd_device, strategy, sampler, cpu_textencoder, disable_nsfw +): + def callback(i, t, latents): + pass + + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 1 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + callback=callback, + ) + cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3, + ) + + +@pytest.mark.parametrize("sd_device", ["cuda"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize( + "sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a] +) +@pytest.mark.parametrize("cpu_textencoder", [False]) +@pytest.mark.parametrize("disable_nsfw", [True]) +def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): + def callback(i, t, latents): + print(f"sd_step_{i}") + + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 1 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + callback=callback, + ) + cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3, + ) + + +@pytest.mark.parametrize("sd_device", ["cuda"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): + def callback(i, t, latents): + pass + + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 1 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=False, + sd_cpu_textencoder=False, + callback=callback, + ) + cfg = get_config( + strategy, + sd_steps=sd_steps, + prompt="Face of a fox, high resolution, sitting on a park bench", + negative_prompt="orange, yellow, small", + sd_sampler=sampler, + sd_match_histograms=True, + ) + + name = f"{sampler}_negative_prompt" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1, + ) + + +@pytest.mark.parametrize("sd_device", ["cuda"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +@pytest.mark.parametrize("cpu_textencoder", [False]) +@pytest.mark.parametrize("disable_nsfw", [False]) +def test_runway_sd_1_5_sd_scale( + sd_device, strategy, sampler, cpu_textencoder, disable_nsfw +): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 1 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=disable_nsfw, + sd_cpu_textencoder=cpu_textencoder, + ) + cfg = get_config( + strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85 + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.3, + ) + + +@pytest.mark.parametrize("sd_device", ["cuda"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 1 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + ) + cfg = get_config( + strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85 + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_local_file_path(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 50 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9a14d90f23c64c7e122c7833e0703180c87ca8 --- /dev/null +++ b/lama_cleaner/web_config.py @@ -0,0 +1,246 @@ +import json +import os +from datetime import datetime + +import gradio as gr +from loguru import logger + +from lama_cleaner.const import * + +_config_file = None + + +def save_config( + host, + port, + model, + sd_local_model_path, + sd_controlnet, + sd_controlnet_method, + device, + gui, + no_gui_auto_close, + no_half, + cpu_offload, + disable_nsfw, + sd_cpu_textencoder, + enable_xformers, + local_files_only, + model_dir, + input, + output_dir, + quality, + enable_interactive_seg, + interactive_seg_model, + interactive_seg_device, + enable_remove_bg, + enable_anime_seg, + enable_realesrgan, + realesrgan_device, + realesrgan_model, + enable_gfpgan, + gfpgan_device, + enable_restoreformer, + restoreformer_device, + enable_gif, +): + config = Config(**locals()) + print(config) + if config.input and not os.path.exists(config.input): + return "[Error] Input file or directory does not exist" + + current_time = datetime.now().strftime("%H:%M:%S") + msg = f"[{current_time}] Successful save config to: {os.path.abspath(_config_file)}" + logger.info(msg) + try: + with open(_config_file, "w", encoding="utf-8") as f: + json.dump(config.dict(), f, indent=4, ensure_ascii=False) + except Exception as e: + return f"Save failed: {str(e)}" + return msg + + +def close_server(*args): + # TODO: make close both browser and server works + import os, signal + + pid = os.getpid() + os.kill(pid, signal.SIGUSR1) + + +def main(config_file: str): + global _config_file + _config_file = config_file + + init_config = load_config(config_file) + + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(scale=1): + save_btn = gr.Button(value="Save configurations") + message = gr.HTML() + + with gr.Tabs(): + with gr.Tab("Common"): + with gr.Row(): + host = gr.Textbox(init_config.host, label="Host") + port = gr.Number(init_config.port, label="Port", precision=0) + + model = gr.Radio( + AVAILABLE_MODELS, label="Model", value=init_config.model + ) + device = gr.Radio( + AVAILABLE_DEVICES, label="Device", value=init_config.device + ) + quality = gr.Slider( + value=95, + label=f"Image Quality ({QUALITY_HELP})", + minimum=75, + maximum=100, + step=1, + ) + + with gr.Column(): + gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") + no_gui_auto_close = gr.Checkbox( + init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}" + ) + + with gr.Column(): + model_dir = gr.Textbox( + init_config.model_dir, label=f"{MODEL_DIR_HELP}" + ) + input = gr.Textbox( + init_config.input, + label=f"Input file or directory. {INPUT_HELP}", + ) + output_dir = gr.Textbox( + init_config.output_dir, + label=f"Output directory. {OUTPUT_DIR_HELP}", + ) + + with gr.Tab("Plugins"): + enable_interactive_seg = gr.Checkbox( + init_config.enable_interactive_seg, label=INTERACTIVE_SEG_HELP + ) + interactive_seg_model = gr.Radio( + AVAILABLE_INTERACTIVE_SEG_MODELS, + label=f"Segment Anything models. {INTERACTIVE_SEG_MODEL_HELP}", + value=init_config.interactive_seg_model, + ) + interactive_seg_device = gr.Radio( + AVAILABLE_INTERACTIVE_SEG_DEVICES, + label="Segment Anything Device", + value=init_config.interactive_seg_device, + ) + with gr.Row(): + enable_remove_bg = gr.Checkbox( + init_config.enable_remove_bg, label=REMOVE_BG_HELP + ) + with gr.Row(): + enable_anime_seg = gr.Checkbox( + init_config.enable_anime_seg, label=ANIMESEG_HELP + ) + + with gr.Row(): + enable_realesrgan = gr.Checkbox( + init_config.enable_realesrgan, label=REALESRGAN_HELP + ) + realesrgan_device = gr.Radio( + REALESRGAN_AVAILABLE_DEVICES, + label="RealESRGAN Device", + value=init_config.realesrgan_device, + ) + realesrgan_model = gr.Radio( + RealESRGANModelNameList, + label="RealESRGAN model", + value=init_config.realesrgan_model, + ) + with gr.Row(): + enable_gfpgan = gr.Checkbox( + init_config.enable_gfpgan, label=GFPGAN_HELP + ) + gfpgan_device = gr.Radio( + GFPGAN_AVAILABLE_DEVICES, + label="GFPGAN Device", + value=init_config.gfpgan_device, + ) + with gr.Row(): + enable_restoreformer = gr.Checkbox( + init_config.enable_restoreformer, label=RESTOREFORMER_HELP + ) + restoreformer_device = gr.Radio( + RESTOREFORMER_AVAILABLE_DEVICES, + label="RestoreFormer Device", + value=init_config.restoreformer_device, + ) + enable_gif = gr.Checkbox(init_config.enable_gif, label=GIF_HELP) + + with gr.Tab("Diffusion Model"): + sd_local_model_path = gr.Textbox( + init_config.sd_local_model_path, label=f"{SD_LOCAL_MODEL_HELP}" + ) + sd_controlnet = gr.Checkbox( + init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}" + ) + sd_controlnet_method = gr.Radio( + SD_CONTROLNET_CHOICES, + label="ControlNet method", + value=init_config.sd_controlnet_method, + ) + no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") + cpu_offload = gr.Checkbox( + init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}" + ) + sd_cpu_textencoder = gr.Checkbox( + init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}" + ) + disable_nsfw = gr.Checkbox( + init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}" + ) + enable_xformers = gr.Checkbox( + init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}" + ) + local_files_only = gr.Checkbox( + init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}" + ) + + save_btn.click( + save_config, + [ + host, + port, + model, + sd_local_model_path, + sd_controlnet, + sd_controlnet_method, + device, + gui, + no_gui_auto_close, + no_half, + cpu_offload, + disable_nsfw, + sd_cpu_textencoder, + enable_xformers, + local_files_only, + model_dir, + input, + output_dir, + quality, + enable_interactive_seg, + interactive_seg_model, + interactive_seg_device, + enable_remove_bg, + enable_anime_seg, + enable_realesrgan, + realesrgan_device, + realesrgan_model, + enable_gfpgan, + gfpgan_device, + enable_restoreformer, + restoreformer_device, + enable_gif, + ], + message, + ) + demo.launch(inbrowser=True, show_api=False) diff --git a/mobile_sam/__init__.py b/mobile_sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d46cd7b483ff84c0c0f5995294504d386434d60 --- /dev/null +++ b/mobile_sam/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +warnings.filterwarnings("ignore", category=UserWarning, module="mobile_sam") + +from .automatic_mask_generator import SamAutomaticMaskGenerator # noqa: E402 +from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, # noqa: E402 + build_sam_vit_t, sam_model_registry) +from .predictor import SamPredictor # noqa: E402 + +__all__ = [ + "build_sam", + "build_sam_vit_h", + "build_sam_vit_l", + "build_sam_vit_b", + "build_sam_vit_t", + "sam_model_registry", + "SamPredictor", + "SamAutomaticMaskGenerator", +] diff --git a/mobile_sam/automatic_mask_generator.py b/mobile_sam/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3c70227dbef3028e6e3673518181923be55272 --- /dev/null +++ b/mobile_sam/automatic_mask_generator.py @@ -0,0 +1,383 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, + box_xyxy_to_xywh, build_all_layer_point_grids, calculate_stability_score, + coco_encode_rle, generate_crop_boxes, is_box_near_crop_edge, + mask_to_rle_pytorch, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, + uncrop_masks, uncrop_points) +from .utils.torch_nms import nms + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + scores, + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + data["iou_preds"], + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size).astype( + np.float32 + ) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + try: + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + except Exception: + keep_by_nms = nms( + boxes.float(), + torch.as_tensor(scores), + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/mobile_sam/build_sam.py b/mobile_sam/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c8b4bac291acc8e68c3b44887d8e15ee75409b --- /dev/null +++ b/mobile_sam/build_sam.py @@ -0,0 +1,158 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT( + img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "vit_t": build_sam_vit_t, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/mobile_sam/modeling/__init__.py b/mobile_sam/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd894c00daa94447609ee3294b36e1986bf7b77f --- /dev/null +++ b/mobile_sam/modeling/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer +from .tiny_vit_sam import TinyViT + +__all__ = [ + "Sam", + "ImageEncoderViT", + "MaskDecoder", + "PromptEncoder", + "TwoWayTransformer", + "TinyViT", +] diff --git a/mobile_sam/modeling/common.py b/mobile_sam/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5c92073d1fd6a44d9a7f3abb9ab610d3ccbcac12 --- /dev/null +++ b/mobile_sam/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/mobile_sam/modeling/image_encoder.py b/mobile_sam/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6e74d81fd0bd8e7c33c3e323ba16ab81f37a779b --- /dev/null +++ b/mobile_sam/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/mobile_sam/modeling/mask_decoder.py b/mobile_sam/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b46fc0bc5ddc7f4368b86eb6a8c8468b21b880e8 --- /dev/null +++ b/mobile_sam/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/mobile_sam/modeling/prompt_encoder.py b/mobile_sam/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4f73520ad1318da91f271a623c8497c8b9a31475 --- /dev/null +++ b/mobile_sam/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/mobile_sam/modeling/sam.py b/mobile_sam/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5303e9b5132098214b60e225a7b9a9d96caa4d --- /dev/null +++ b/mobile_sam/modeling/sam.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple, Union + +from .tiny_vit_sam import TinyViT +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: Union[ImageEncoderViT, TinyViT], + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/mobile_sam/modeling/tiny_vit_sam.py b/mobile_sam/modeling/tiny_vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..87635e6ed768c4132713b78234e1c0fede5eb9d9 --- /dev/null +++ b/mobile_sam/modeling/tiny_vit_sam.py @@ -0,0 +1,721 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath +from timm.models.layers import to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c = 2 + if (out_dim == 320 or out_dim == 448 or out_dim == 576): + stride_c = 1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.register_buffer('ab', + self.attention_biases[:, self.attention_bias_idxs], + persistent=False) + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size = img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + # print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B, _, C = x.size() + x = x.view(B, 64, 64, C) + x = x.permute(0, 3, 1, 2) + x = self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + # x = self.norm_head(x) + # x = self.head(x) + return x + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' + url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) diff --git a/mobile_sam/modeling/transformer.py b/mobile_sam/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d99f8e8265b5780dd3be1d8c6bbd33156ac1d8f4 --- /dev/null +++ b/mobile_sam/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/mobile_sam/predictor.py b/mobile_sam/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4d0a2e978ffe9123c07164db7b5c4e695bb596 --- /dev/null +++ b/mobile_sam/predictor.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from mobile_sam.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + # import pdb; pdb.set_trace() + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/mobile_sam/utils/__init__.py b/mobile_sam/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/mobile_sam/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/mobile_sam/utils/amg.py b/mobile_sam/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3177dea0c282cef17942a35479bda5b299d4b8 --- /dev/null +++ b/mobile_sam/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx: idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords.int(), dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords.int(), dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords.int(), dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords.int(), dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/mobile_sam/utils/onnx.py b/mobile_sam/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a9d9e2f1c5990f6b279ef7d1bb847063c68e5e --- /dev/null +++ b/mobile_sam/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/mobile_sam/utils/torch_nms.py b/mobile_sam/utils/torch_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..82f1a1f5c0dcab0292fb414723ba2c01947f081a --- /dev/null +++ b/mobile_sam/utils/torch_nms.py @@ -0,0 +1,20 @@ +import torch +from torchvision.ops.boxes import box_iou + + +def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/mobile_sam/utils/transforms.py b/mobile_sam/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f07693952bbffcd23c5226255d1f649476ca7ce6 --- /dev/null +++ b/mobile_sam/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9c1445a379c5dd000d4fa08cffa0c15edabe394 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.2.2 +torchvision +accelerate +diffusers +gradio<4.0.0 +huggingface-hub +numpy +opencv-python +pillow +segment-anything +transformers<5.0.0 +xformers==0.0.26 +# lama-cleaner +ultralytics +tqdm +packaging +loguru +rich +pydantic +timm +onnxruntime +hydra-core +iopath diff --git a/requirements_mac.txt b/requirements_mac.txt new file mode 100644 index 0000000000000000000000000000000000000000..b16f9e90f7bad11f32226902aac76c0c6a51ee56 --- /dev/null +++ b/requirements_mac.txt @@ -0,0 +1,23 @@ +torch==2.2.2 +torchvision +accelerate +diffusers +gradio<4.0.0 +huggingface-hub +numpy +opencv-python +Pillow +segment-anything +transformers<5.0.0 +# xformers +# lama-cleaner +ultralytics +tqdm +packaging +loguru +rich +pydantic +timm +onnxruntime +hydra-core +iopath diff --git a/sam2/__init__.py b/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a727abab4fae1d47f2b2c23edad0a967546288 --- /dev/null +++ b/sam2/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings + +from hydra import initialize_config_dir, initialize_config_module # noqa: F401 + +warnings.filterwarnings("ignore", category=UserWarning, module="sam2") + +inpa_basedir = os.path.abspath(os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))) +configs_path = os.path.join(inpa_basedir, "sam2_configs") + +try: + initialize_config_dir(configs_path, version_base="1.2") +except TypeError: + initialize_config_dir(configs_path) +# initialize_config_module("sam2_configs", version_base="1.2") diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fe57b6d87d510811471c3c13ef07d76640f8a8f6 --- /dev/null +++ b/sam2/automatic_mask_generator.py @@ -0,0 +1,446 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, + box_xyxy_to_xywh, build_all_layer_point_grids, + calculate_stability_score, coco_encode_rle, generate_crop_boxes, + is_box_near_crop_edge, mask_to_rle_pytorch, remove_small_regions, + rle_to_mask, uncrop_boxes_xyxy, uncrop_masks, uncrop_points) + +from .utils.torch_nms import nms + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + scores, + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + data["iou_preds"], + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + # points = torch.as_tensor(points, device=self.predictor.device) + points = torch.as_tensor(points.astype(np.float32), device=self.predictor.device) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + try: + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + except Exception: + keep_by_nms = nms( + boxes.float(), + torch.as_tensor(scores), + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/sam2/build_sam.py b/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..50f6c91cf442394573d2eaccb0d7112d6995f684 --- /dev/null +++ b/sam2/build_sam.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu")["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/sam2/csrc/connected_components.cu b/sam2/csrc/connected_components.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e3fbee0eba762c7198ace220660ceadd13d7402 --- /dev/null +++ b/sam2/csrc/connected_components.cu @@ -0,0 +1,289 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. + +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// adapted from https://github.com/zsef123/Connected_components_PyTorch +// with license found in the LICENSE_cctorch file in the root directory. +#include +#include +#include +#include +#include +#include + +// 2d +#define BLOCK_ROWS 16 +#define BLOCK_COLS 16 + +namespace cc2d { + +template +__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { + return (bitmap >> pos) & 1; +} + +__device__ int32_t find(const int32_t* s_buf, int32_t n) { + while (s_buf[n] != n) + n = s_buf[n]; + return n; +} + +__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { + const int32_t id = n; + while (s_buf[n] != n) { + n = s_buf[n]; + s_buf[id] = n; + } + return n; +} + +__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { + bool done; + do { + a = find(s_buf, a); + b = find(s_buf, b); + + if (a < b) { + int32_t old = atomicMin(s_buf + b, a); + done = (old == b); + b = old; + } else if (b < a) { + int32_t old = atomicMin(s_buf + a, b); + done = (old == a); + a = old; + } else + done = true; + + } while (!done); +} + +__global__ void +init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + label[idx] = idx; +} + +__global__ void +merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + uint32_t P = 0; + + if (img[idx]) + P |= 0x777; + if (row + 1 < H && img[idx + W]) + P |= 0x777 << 4; + if (col + 1 < W && img[idx + 1]) + P |= 0x777 << 1; + + if (col == 0) + P &= 0xEEEE; + if (col + 1 >= W) + P &= 0x3333; + else if (col + 2 >= W) + P &= 0x7777; + + if (row == 0) + P &= 0xFFF0; + if (row + 1 >= H) + P &= 0xFF; + + if (P > 0) { + // If need check about top-left pixel(if flag the first bit) and hit the + // top-left pixel + if (hasBit(P, 0) && img[idx - W - 1]) { + union_(label, idx, idx - 2 * W - 2); // top left block + } + + if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) + union_(label, idx, idx - 2 * W); // top bottom block + + if (hasBit(P, 3) && img[idx + 2 - W]) + union_(label, idx, idx - 2 * W + 2); // top right block + + if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) + union_(label, idx, idx - 2); // just left block + } +} + +__global__ void compression(int32_t* label, const int32_t W, const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + find_n_compress(label, idx); +} + +__global__ void final_labeling( + const uint8_t* img, + int32_t* label, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx] + 1; + + if (img[idx]) + label[idx] = y; + else + label[idx] = 0; + + if (col + 1 < W) { + if (img[idx + 1]) + label[idx + 1] = y; + else + label[idx + 1] = 0; + + if (row + 1 < H) { + if (img[idx + W + 1]) + label[idx + W + 1] = y; + else + label[idx + W + 1] = 0; + } + } + + if (row + 1 < H) { + if (img[idx + W]) + label[idx + W] = y; + else + label[idx + W] = 0; + } +} + +__global__ void init_counting( + const int32_t* label, + int32_t* count_init, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + atomicAdd(count_init + count_idx, 1); + } +} + +__global__ void final_counting( + const int32_t* label, + const int32_t* count_init, + int32_t* count_final, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + count_final[idx] = count_init[count_idx]; + } else { + count_final[idx] = 0; + } +} + +} // namespace cc2d + +std::vector get_connected_componnets( + const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); + AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM( + inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); + + const uint32_t N = inputs.size(0); + const uint32_t C = inputs.size(1); + const uint32_t H = inputs.size(2); + const uint32_t W = inputs.size(3); + + AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM((H % 2) == 0, "height must be an even number"); + AT_ASSERTM((W % 2) == 0, "width must be an even number"); + + // label must be uint32_t + auto label_options = + torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); + torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); + + dim3 grid = dim3( + ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, + ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); + dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); + dim3 grid_count = + dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); + dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + for (int n = 0; n < N; n++) { + uint32_t offset = n * H * W; + + cc2d::init_labeling<<>>( + labels.data_ptr() + offset, W, H); + cc2d::merge<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + cc2d::compression<<>>( + labels.data_ptr() + offset, W, H); + cc2d::final_labeling<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + + // get the counting of each pixel + cc2d::init_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + W, + H); + cc2d::final_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + counts_final.data_ptr() + offset, + W, + H); + } + + // returned values are [labels, counts] + std::vector outputs; + outputs.push_back(labels); + outputs.push_back(counts_final); + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "get_connected_componnets", + &get_connected_componnets, + "get_connected_componnets"); +} diff --git a/sam2/modeling/__init__.py b/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/backbones/__init__.py b/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..f8dea37b8dbc6cd7660e27faa6a855c1c926adbe --- /dev/null +++ b/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5625fd33ab4b49904ef6056c42f26afd2ad8f1ad --- /dev/null +++ b/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b7807b275256f83a83e5d1baa6c045ad6c124807 --- /dev/null +++ b/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4adb5cf0a335f7013835dd31e3863b6b04e738 --- /dev/null +++ b/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from sam2.modeling.sam.transformer import RoPEAttention + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..83f98a2544b225f5bdb6e9a046380b8df5887a30 --- /dev/null +++ b/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..8f41cf91739001ccedbd61e174df8d661310aee1 --- /dev/null +++ b/sam2/modeling/position_encoding.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/sam2/modeling/sam/__init__.py b/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..19a45f49a294f72e39cd7006eeb1ca91a4266c94 --- /dev/null +++ b/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..91d9952ca8078bedd04fdc2ea0d900529e432528 --- /dev/null +++ b/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5e11808b6dc01b48fec292ac8fddddf7876ffe --- /dev/null +++ b/sam2/modeling/sam/transformer.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis + +from sam2.modeling.sam2_utils import MLP +from sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + # with torch.backends.cuda.sdp_kernel( + # enable_flash=USE_FLASH_ATTN, + # # if Flash attention kernel is off, then math kernel needs to be enabled + # enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + # enable_mem_efficient=OLD_GPU, + # ): + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + # with torch.backends.cuda.sdp_kernel( + # enable_flash=USE_FLASH_ATTN, + # # if Flash attention kernel is off, then math kernel needs to be enabled + # enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + # enable_mem_efficient=OLD_GPU, + # ): + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9014673b4f8f057cc95bd3cb69e74237c95c98 --- /dev/null +++ b/sam2/modeling/sam2_base.py @@ -0,0 +1,829 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].cuda(non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b28713e3ea2b14da57798b2a8030d4882bb98221 --- /dev/null +++ b/sam2/modeling/sam2_utils.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..89cb0468e1fcb7e3b1e359949f5d5a0c26c0c0e2 --- /dev/null +++ b/sam2/sam2_image_predictor.py @@ -0,0 +1,446 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to + the maximum area of fill_hole_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tupele of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa492a0e092f24dcbeed27d4a206ebcc1fa5aa8 --- /dev/null +++ b/sam2/sam2_video_predictor.py @@ -0,0 +1,898 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize a inference state.""" + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = torch.device("cuda") + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points( + self, + inference_state, + frame_idx, + obj_id, + points, + labels, + clear_old_points=True, + normalize_coords=True, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..2543a17555e304caaf80f8e62e7b19c3d9da0c36 --- /dev/null +++ b/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx: idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f8efad91aadbc70dff2b42908be14d51957bc9a4 --- /dev/null +++ b/sam2/utils/misc.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] boxes, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self._images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.cuda(non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError("Only JPEG frames are supported at this moment") + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, image_size, offload_video_to_cpu, img_mean, img_std + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/sam2/utils/torch_nms.py b/sam2/utils/torch_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..82f1a1f5c0dcab0292fb414723ba2c01947f081a --- /dev/null +++ b/sam2/utils/torch_nms.py @@ -0,0 +1,20 @@ +import torch +from torchvision.ops.boxes import box_iou + + +def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d405a5b74e4878898b6b95664a868a3da0cddd --- /dev/null +++ b/sam2/utils/transforms.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = nn.Sequential( + Resize((self.resolution, self.resolution), antialias=True), + Normalize(self.mean, self.std), + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + labels, areas = get_connected_components(mask_flat <= self.mask_threshold) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components(mask_flat > self.mask_threshold) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/sam2_configs/__init__.py b/sam2_configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/sam2_configs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2_configs/sam2_hiera_b+.yaml b/sam2_configs/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9fdcfa4054b6d03a159c4fad01515fd3153d23d8 --- /dev/null +++ b/sam2_configs/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2_configs/sam2_hiera_l.yaml b/sam2_configs/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5ef6cc3a4c7656f2791d12575b70b3dbb665bb25 --- /dev/null +++ b/sam2_configs/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2_configs/sam2_hiera_s.yaml b/sam2_configs/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6ebeeae747874ba1938ffdf69876202e7a98c0a --- /dev/null +++ b/sam2_configs/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2_configs/sam2_hiera_t.yaml b/sam2_configs/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c0b3a36094f1cee9b6d320eacd1c5774e019fb2 --- /dev/null +++ b/sam2_configs/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/segment_anything_fb/__init__.py b/segment_anything_fb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0c366d59522d086e17a36a03e6eb73171e8066 --- /dev/null +++ b/segment_anything_fb/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator + +__all__ = [ + "build_sam", + "build_sam_vit_h", + "build_sam_vit_l", + "build_sam_vit_b", + "sam_model_registry", + "SamPredictor", + "SamAutomaticMaskGenerator", +] diff --git a/segment_anything_fb/automatic_mask_generator.py b/segment_anything_fb/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3c70227dbef3028e6e3673518181923be55272 --- /dev/null +++ b/segment_anything_fb/automatic_mask_generator.py @@ -0,0 +1,383 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, + box_xyxy_to_xywh, build_all_layer_point_grids, calculate_stability_score, + coco_encode_rle, generate_crop_boxes, is_box_near_crop_edge, + mask_to_rle_pytorch, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, + uncrop_masks, uncrop_points) +from .utils.torch_nms import nms + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + scores, + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + data["iou_preds"], + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size).astype( + np.float32 + ) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + try: + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + except Exception: + keep_by_nms = nms( + boxes.float(), + torch.as_tensor(scores), + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/segment_anything_fb/build_sam.py b/segment_anything_fb/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..72dd32e4838ac8b34b5a6773ce5a43867b203a7d --- /dev/null +++ b/segment_anything_fb/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/segment_anything_fb/modeling/__init__.py b/segment_anything_fb/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81ae8454e62dfcab8cca0e3f79929199f909a00e --- /dev/null +++ b/segment_anything_fb/modeling/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer + +__all__ = [ + "Sam", + "ImageEncoderViT", + "MaskDecoder", + "PromptEncoder", + "TwoWayTransformer", +] diff --git a/segment_anything_fb/modeling/common.py b/segment_anything_fb/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5c92073d1fd6a44d9a7f3abb9ab610d3ccbcac12 --- /dev/null +++ b/segment_anything_fb/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/segment_anything_fb/modeling/image_encoder.py b/segment_anything_fb/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6e74d81fd0bd8e7c33c3e323ba16ab81f37a779b --- /dev/null +++ b/segment_anything_fb/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/segment_anything_fb/modeling/mask_decoder.py b/segment_anything_fb/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b46fc0bc5ddc7f4368b86eb6a8c8468b21b880e8 --- /dev/null +++ b/segment_anything_fb/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/segment_anything_fb/modeling/prompt_encoder.py b/segment_anything_fb/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4f73520ad1318da91f271a623c8497c8b9a31475 --- /dev/null +++ b/segment_anything_fb/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/segment_anything_fb/modeling/sam.py b/segment_anything_fb/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..49bf2d7e26bbaef527bdeb8b8299adda054fa195 --- /dev/null +++ b/segment_anything_fb/modeling/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/segment_anything_fb/modeling/transformer.py b/segment_anything_fb/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d99f8e8265b5780dd3be1d8c6bbd33156ac1d8f4 --- /dev/null +++ b/segment_anything_fb/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/segment_anything_fb/predictor.py b/segment_anything_fb/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..dadf27ea8e9962418f7714d08ebfadcf9c4d3182 --- /dev/null +++ b/segment_anything_fb/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/segment_anything_fb/utils/__init__.py b/segment_anything_fb/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/segment_anything_fb/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/segment_anything_fb/utils/amg.py b/segment_anything_fb/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3177dea0c282cef17942a35479bda5b299d4b8 --- /dev/null +++ b/segment_anything_fb/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx: idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords.int(), dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords.int(), dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords.int(), dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords.int(), dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/segment_anything_fb/utils/onnx.py b/segment_anything_fb/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a9d9e2f1c5990f6b279ef7d1bb847063c68e5e --- /dev/null +++ b/segment_anything_fb/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/segment_anything_fb/utils/torch_nms.py b/segment_anything_fb/utils/torch_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..82f1a1f5c0dcab0292fb414723ba2c01947f081a --- /dev/null +++ b/segment_anything_fb/utils/torch_nms.py @@ -0,0 +1,20 @@ +import torch +from torchvision.ops.boxes import box_iou + + +def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/segment_anything_fb/utils/transforms.py b/segment_anything_fb/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f07693952bbffcd23c5226255d1f649476ca7ce6 --- /dev/null +++ b/segment_anything_fb/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/segment_anything_hq/__init__.py b/segment_anything_hq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f197f87732aa2e19f6c95a627efb0c8abc51c1 --- /dev/null +++ b/segment_anything_hq/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .build_sam_baseline import sam_model_registry_baseline +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator + +__all__ = [ + "build_sam", + "build_sam_vit_h", + "build_sam_vit_l", + "build_sam_vit_b", + "sam_model_registry", + "sam_model_registry_baseline", + "SamPredictor", + "SamAutomaticMaskGenerator", +] diff --git a/segment_anything_hq/automatic_mask_generator.py b/segment_anything_hq/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3701cc687d8378c0a61f34c02522c82be0d3b3d3 --- /dev/null +++ b/segment_anything_hq/automatic_mask_generator.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, + box_xyxy_to_xywh, build_all_layer_point_grids, calculate_stability_score, + coco_encode_rle, generate_crop_boxes, is_box_near_crop_edge, + mask_to_rle_pytorch, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, + uncrop_masks, uncrop_points) +from .utils.torch_nms import nms + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image, multimask_output) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + scores, + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + multimask_output: bool = True, + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # CPU Offloading + self.predictor.model.image_encoder.to("cpu") + self.predictor.features.to(self.predictor.device) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + try: + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + except Exception: + keep_by_nms = nms( + data["boxes"].float(), + data["iou_preds"], + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + multimask_output: bool = True, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size).astype( + np.float32 + ) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + try: + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + except Exception: + keep_by_nms = nms( + boxes.float(), + torch.as_tensor(scores), + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/segment_anything_hq/build_sam.py b/segment_anything_hq/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7f5def2412d72ec24e9a498928be9fb4aae52c --- /dev/null +++ b/segment_anything_hq/build_sam.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer +import platform + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoderHQ( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=encoder_embed_dim, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + if platform.system() == "Darwin": + if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + state_dict = torch.load(f, map_location=torch.device("mps")) + else: + state_dict = torch.load(f, map_location=torch.device("cpu")) + else: + if torch.cuda.is_available(): + state_dict = torch.load(f) + else: + state_dict = torch.load(f, map_location=torch.device("cpu")) + # info = sam.load_state_dict(state_dict, strict=False) + # print(info) + sam.load_state_dict(state_dict, strict=False) + for n, p in sam.named_parameters(): + if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: + p.requires_grad = False + + return sam diff --git a/segment_anything_hq/build_sam_baseline.py b/segment_anything_hq/build_sam_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a0b5b9adb5417f3fcb478471d017cefb65078f --- /dev/null +++ b/segment_anything_hq/build_sam_baseline.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry_baseline = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/segment_anything_hq/modeling/__init__.py b/segment_anything_hq/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a38eeeb41844a53369c2101be8ed82c160dde482 --- /dev/null +++ b/segment_anything_hq/modeling/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder_hq import MaskDecoderHQ +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer + +__all__ = [ + "Sam", + "ImageEncoderViT", + "MaskDecoderHQ", + "MaskDecoder", + "PromptEncoder", + "TwoWayTransformer", +] diff --git a/segment_anything_hq/modeling/common.py b/segment_anything_hq/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5c92073d1fd6a44d9a7f3abb9ab610d3ccbcac12 --- /dev/null +++ b/segment_anything_hq/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/segment_anything_hq/modeling/image_encoder.py b/segment_anything_hq/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..81644dffaab32f255f150d608c471292388cd1c7 --- /dev/null +++ b/segment_anything_hq/modeling/image_encoder.py @@ -0,0 +1,398 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + interm_embeddings = [] + for blk in self.blocks: + x = blk(x) + if blk.window_size == 0: + interm_embeddings.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x, interm_embeddings + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/segment_anything_hq/modeling/mask_decoder.py b/segment_anything_hq/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd734ed12f20ba714725d30667928aef6e960e6 --- /dev/null +++ b/segment_anything_hq/modeling/mask_decoder.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + interm_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/segment_anything_hq/modeling/mask_decoder_hq.py b/segment_anything_hq/modeling/mask_decoder_hq.py new file mode 100644 index 0000000000000000000000000000000000000000..9587f7ca2485c80d23a1310f1934499b2c71d54c --- /dev/null +++ b/segment_anything_hq/modeling/mask_decoder_hq.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Modified by HQ-SAM team +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoderHQ(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + vit_dim: int = 1024, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + # HQ-SAM parameters + self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Ouptput-Token + self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Ouptput-Token + self.num_mask_tokens = self.num_mask_tokens + 1 + + # three conv fusion layers for obtaining HQ-Feature + self.compress_vit_feat = nn.Sequential( + nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2)) + + self.embedding_encoder = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + ) + self.embedding_maskfeature = nn.Sequential( + nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1)) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + interm_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the ViT image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT + hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) + + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + hq_features=hq_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + # mask with highest score + mask_slice = slice(1, self.num_mask_tokens-1) + iou_pred = iou_pred[:, mask_slice] + iou_pred, max_iou_idx = torch.max(iou_pred, dim=1) + iou_pred = iou_pred.unsqueeze(1) + masks_multi = masks[:, mask_slice, :, :] + masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1) + else: + # singale mask output, default + mask_slice = slice(0, 1) + iou_pred = iou_pred[:, mask_slice] + masks_sam = masks[:, mask_slice] + + masks_hq = masks[:, slice(self.num_mask_tokens-1, self.num_mask_tokens)] + if hq_token_only: + masks = masks_hq + else: + masks = masks_sam + masks_hq + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + hq_features: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + upscaled_embedding_sam = self.output_upscaling(src) + upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b, 1, 1, 1) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + if i < self.num_mask_tokens - 1: + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + else: + hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) + + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding_sam.shape + + masks_sam = (hyper_in[:, :self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) + masks_sam_hq = (hyper_in[:, self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w) + masks = torch.cat([masks_sam, masks_sam_hq], dim=1) + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/segment_anything_hq/modeling/prompt_encoder.py b/segment_anything_hq/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4f73520ad1318da91f271a623c8497c8b9a31475 --- /dev/null +++ b/segment_anything_hq/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/segment_anything_hq/modeling/sam.py b/segment_anything_hq/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..a30a0f07305370adc0704d025d4b9b3e6bee5a37 --- /dev/null +++ b/segment_anything_hq/modeling/sam.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + hq_token_only: bool = False, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings, interm_embeddings = self.image_encoder(input_images) + interm_embeddings = interm_embeddings[0] # early layer + + outputs = [] + for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/segment_anything_hq/modeling/transformer.py b/segment_anything_hq/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d99f8e8265b5780dd3be1d8c6bbd33156ac1d8f4 --- /dev/null +++ b/segment_anything_hq/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/segment_anything_hq/predictor.py b/segment_anything_hq/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..eec89afbc4e65fc5fa71160f76fc34c8d49d5288 --- /dev/null +++ b/segment_anything_hq/predictor.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + # import pdb;pdb.set_trace() + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + # import pdb;pdb.set_trace() + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features, self.interm_features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + hq_token_only=hq_token_only, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=self.interm_features, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/segment_anything_hq/utils/__init__.py b/segment_anything_hq/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3 --- /dev/null +++ b/segment_anything_hq/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/segment_anything_hq/utils/amg.py b/segment_anything_hq/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3177dea0c282cef17942a35479bda5b299d4b8 --- /dev/null +++ b/segment_anything_hq/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx: idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords.int(), dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords.int(), dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords.int(), dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords.int(), dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/segment_anything_hq/utils/onnx.py b/segment_anything_hq/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a9d9e2f1c5990f6b279ef7d1bb847063c68e5e --- /dev/null +++ b/segment_anything_hq/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/segment_anything_hq/utils/torch_nms.py b/segment_anything_hq/utils/torch_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..82f1a1f5c0dcab0292fb414723ba2c01947f081a --- /dev/null +++ b/segment_anything_hq/utils/torch_nms.py @@ -0,0 +1,20 @@ +import torch +from torchvision.ops.boxes import box_iou + + +def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: + order = torch.argsort(-scores) + keep = [] + + while order.numel() > 0: + i = order[0] + keep.append(i.item()) + + if order.numel() == 1: + break + + ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] + mask = ious <= iou_threshold + order = order[1:][mask] + + return torch.tensor(keep, device=bboxes.device) diff --git a/segment_anything_hq/utils/transforms.py b/segment_anything_hq/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f07693952bbffcd23c5226255d1f649476ca7ce6 --- /dev/null +++ b/segment_anything_hq/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww)