AlekseyCalvin committed on
Commit
ed633c6
·
verified ·
1 Parent(s): ec0e9e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -37
app.py CHANGED
@@ -4,55 +4,27 @@ import logging
4
  import torch
5
  from PIL import Image
6
  import spaces
7
- from diffusers import DiffusionPipeline
8
  import copy
9
  import random
10
  import time
11
  from huggingface_hub import hf_hub_download
12
- from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
13
- from accelerate import init_empty_weights
14
- from convert_nf4_flux import replace_with_bnb_linear, create_quantized_param, check_quantized_param
15
  from diffusers import FluxTransformer2DModel, FluxPipeline
16
  import safetensors.torch
17
  import gc
18
- import torch
19
-
20
- # Set dtype and check for float8 support
21
- dtype = torch.bfloat16
22
- is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
23
-
24
- ckpt_path = hf_hub_download("ABDALLALSWAITI/Maxwell", filename="diffusion_pytorch_model.safetensors")
25
- original_state_dict = safetensors.torch.load_file(ckpt_path)
26
 
27
- with init_empty_weights():
28
- config = FluxTransformer2DModel.load_config("ABDALLALSWAITI/Maxwell")
29
- model = FluxTransformer2DModel.from_config(config).to(dtype)
30
- expected_state_dict_keys = list(model.state_dict().keys())
31
 
32
- # Load the state dict into the quantized model
33
- for param_name, param in original_state_dict.items():
34
- if param_name not in expected_state_dict_keys:
35
- continue
36
-
37
- is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
38
- if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
39
- param = param.to(dtype)
40
-
41
- if not check_quantized_param(model, param_name):
42
- set_module_tensor_to_device(model, param_name, device=0, value=param)
43
- else:
44
- create_quantized_param(
45
- model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
46
- )
47
 
48
- # Clean up
49
- del original_state_dict
50
- gc.collect()
51
 
52
- # Print model size
53
- print(compute_module_sizes(model)[""] / 1024 / 1204)
 
 
54
 
55
- pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
56
  pipe.enable_model_cpu_offload()
57
 
58
  # Load LoRAs from JSON file
 
4
  import torch
5
  from PIL import Image
6
  import spaces
 
7
  import copy
8
  import random
9
  import time
10
  from huggingface_hub import hf_hub_download
 
 
 
11
  from diffusers import FluxTransformer2DModel, FluxPipeline
12
  import safetensors.torch
13
  import gc
 
 
 
 
 
 
 
 
14
 
15
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
16
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
17
+ os.environ["HF_HUB_CACHE"] = cache_path
18
+ os.environ["HF_HOME"] = cache_path
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ torch.backends.cuda.matmul.allow_tf32 = True
 
 
22
 
23
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
24
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
25
+ pipe.fuse_lora(lora_scale=0.125)
26
+ pipe.to(device="cuda", dtype=torch.bfloat16)
27
 
 
28
  pipe.enable_model_cpu_offload()
29
 
30
  # Load LoRAs from JSON file