DJStomp committed
Commit b95cb12 · verified · 1 Parent(s): ba8904e

Prioritize ZeroGPU init over torch import

Files changed (1): app.py (+6 -4)
app.py CHANGED
@@ -1,5 +1,7 @@
 import os
 import random
+
+import spaces
 import gradio as gr
 import torch
 from diffusers.utils import load_image
@@ -7,7 +9,6 @@ from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipe
 from diffusers.models.controlnet_flux import FluxControlNetModel
 import numpy as np
 from huggingface_hub import login, snapshot_download
-import spaces


 # Configuration
@@ -52,9 +53,10 @@ def infer(
     seed,
     randomize_seed,
 ):
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-    print(f"Inference: using device: {device} (torch_dtype={torch_dtype})")
+    global DEVICE, TORCH_DTYPE
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    print(f"Inference: using device: {DEVICE} (torch_dtype={TORCH_DTYPE})")
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
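For context, a minimal sketch of the import order at the top of app.py after this change (surrounding imports assumed to be exactly as shown in the diff above); on ZeroGPU Spaces the spaces package is generally imported before torch so the Space can set up GPU/CUDA handling first:

import os
import random

import spaces  # ZeroGPU hook: imported before torch so Spaces can manage CUDA initialization
import gradio as gr
import torch
from diffusers.utils import load_image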