manbeast3b
committed on
Update src/pipeline.py
Browse files- src/pipeline.py +4 -6
src/pipeline.py
CHANGED
@@ -40,8 +40,6 @@ torch.backends.cudnn.enabled = True
|
|
# Module-level configuration (reconstructed from a diff rendering; scaffolding stripped).
Pipeline = None  # placeholder object used only as the return annotation of load_pipeline()
ckpt_id = "black-forest-labs/FLUX.1-schnell"  # base checkpoint repo id on the HF Hub
ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"  # pinned commit for reproducible weights
TinyVAE = "madebyollin/taef1"  # tiny autoencoder (taef1) repo id — NOTE(review): unused in the visible code
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"  # pinned revision for TinyVAE
|
46 |
def empty_cache():
|
47 |
gc.collect()
|
@@ -50,13 +48,13 @@ def empty_cache():
|
|
50 |
torch.cuda.reset_peak_memory_stats()
|
51 |
|
52 |
def load_pipeline() -> Pipeline:
    """Build, move to CUDA, and warm up the FLUX.1-schnell text-to-image pipeline.

    Returns:
        The ready-to-use ``FluxPipeline``. A single warm-up generation is run
        before returning so CUDA kernels are compiled/cached ahead of the
        first real request.
    """
    # Fine-tuned T5 text encoder, pinned to an exact revision.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    )
    # Transformer weights are read straight out of the local HF Hub cache
    # (assumes the snapshot is already downloaded — TODO confirm).
    path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/"
        "cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    # Warm-up. The original wrapped this in `for _ in range(1):` — a
    # one-iteration loop; a single direct call is equivalent and clearer.
    pipeline(
        prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return pipeline
|
62 |
|
|
|
# Module-level configuration (reconstructed from a diff rendering; scaffolding stripped).
Pipeline = None  # placeholder object used only as the return annotation of load_pipeline()
ckpt_id = "black-forest-labs/FLUX.1-schnell"  # base checkpoint repo id on the HF Hub
ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"  # pinned commit for reproducible weights
|
|
|
|
|
43 |
|
44 |
def empty_cache():
|
45 |
gc.collect()
|
|
|
48 |
torch.cuda.reset_peak_memory_stats()
|
49 |
|
50 |
def load_pipeline() -> Pipeline:
    """Build, move to CUDA, and warm up the FLUX.1-schnell text-to-image pipeline.

    Returns:
        The ready-to-use ``FluxPipeline`` with a fine-tuned T5 encoder, a
        locally cached transformer, and an int8 weight-only quantized VAE.
        A single warm-up generation is run before returning so CUDA kernels
        are compiled/cached ahead of the first real request.
    """
    # Fine-tuned T5 text encoder, pinned revision. channels_last memory format
    # is applied as in the original; its benefit for a pure-transformer
    # encoder is unverified — NOTE(review): confirm this is intentional.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    # Transformer weights are read straight out of the local HF Hub cache
    # (assumes the snapshot is already downloaded — TODO confirm).
    path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/"
        "cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)
    # BUG FIX: the original called AutoencoderKL.from_pretrained(ids,
    # revision=Revision, ...) — `ids` and `Revision` are defined nowhere in
    # this file (its constants are ckpt_id/ckpt_revision), and the quantized
    # VAE was immediately discarded, so the quantization never affected the
    # pipeline. Load from the pinned base checkpoint, keep a reference, and
    # hand it to the pipeline below.
    vae = AutoencoderKL.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    quantize_(vae, int8_weight_only())  # torchao: in-place int8 weight-only quantization
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    # Warm-up under inference_mode: no autograd bookkeeping while CUDA
    # kernels are compiled/cached.
    with torch.inference_mode():
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
|
60 |
|