|
import torch |
|
|
|
class LatentFormat:
    """Base class describing how latents are rescaled for a model family.

    The attributes below are class-level defaults. The subclasses visible in
    this file overwrite them on the instance inside ``__init__``.
    """

    # Factor multiplied in by process_in() and divided out by process_out().
    scale_factor = 1.0
    # List of 3-value rows when a subclass provides one, otherwise None.
    # NOTE(review): presumably one (R, G, B) weight row per latent channel,
    # used for cheap previews -- confirm against the callers.
    latent_rgb_factors = None
    # Decoder name string when a subclass provides one, otherwise None.
    # NOTE(review): presumably identifies a TAESD preview decoder -- confirm.
    taesd_decoder_name = None

    def process_in(self, latent):
        """Return ``latent`` multiplied by ``self.scale_factor``."""
        return latent * self.scale_factor

    def process_out(self, latent):
        """Return ``latent`` divided by ``self.scale_factor``.

        This is the inverse of :meth:`process_in`.
        """
        return latent / self.scale_factor
|
|
|
class SD15(LatentFormat):
    """Latent format settings for ``SD15``.

    NOTE(review): the name suggests Stable Diffusion 1.5 -- confirm.
    """

    def __init__(self, scale_factor=0.18215):
        """Store the per-instance settings.

        Args:
            scale_factor: Factor used by the inherited ``process_in`` /
                ``process_out``; defaults to 0.18215.
        """
        self.scale_factor = scale_factor
        # Four rows of three weights each.
        # NOTE(review): presumably rows are latent channels and columns are
        # R, G, B -- confirm against the preview code that consumes this.
        self.latent_rgb_factors = [
            [0.3512, 0.2297, 0.3227],
            [0.3250, 0.4974, 0.2350],
            [-0.2829, 0.1762, 0.2721],
            [-0.2120, -0.2616, -0.7177],
        ]
        self.taesd_decoder_name = "taesd_decoder"
|
|
|
class SDXL(LatentFormat):
    """Latent format settings for ``SDXL``."""

    def __init__(self):
        """Store the fixed scale factor, RGB factor rows and decoder name."""
        self.scale_factor = 0.13025
        # Four rows of three weights each.
        # NOTE(review): presumably rows are latent channels and columns are
        # R, G, B -- confirm against the preview code that consumes this.
        self.latent_rgb_factors = [
            [0.3920, 0.4054, 0.4549],
            [-0.2634, -0.0196, 0.0653],
            [0.0568, 0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076],
        ]
        self.taesd_decoder_name = "taesdxl_decoder"
|
|
|
class SDXL_Playground_2_5(LatentFormat):
    """Latent format that normalises latents with a per-channel mean and std.

    Unlike the base class, ``process_in`` / ``process_out`` here shift and
    scale by ``latents_mean`` / ``latents_std`` in addition to applying
    ``scale_factor``.
    """

    def __init__(self):
        """Store the scale factor, channel statistics and preview settings."""
        self.scale_factor = 0.5
        # Four values reshaped to (1, 4, 1, 1) so they broadcast against the
        # latent in process_in / process_out.
        # NOTE(review): assumes latents are (batch, 4, H, W) -- confirm.
        self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
        self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)

        # Same rows and decoder name as the SDXL class in this file.
        self.latent_rgb_factors = [
            [0.3920, 0.4054, 0.4549],
            [-0.2634, -0.0196, 0.0653],
            [0.0568, 0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076],
        ]
        self.taesd_decoder_name = "taesdxl_decoder"

    def process_in(self, latent):
        """Return ``(latent - mean) * scale_factor / std``.

        The statistics are moved to ``latent``'s device and dtype first so the
        arithmetic does not mix devices or dtypes.
        """
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return (latent - latents_mean) * self.scale_factor / latents_std

    def process_out(self, latent):
        """Return ``latent * std / scale_factor + mean``.

        This is the inverse of :meth:`process_in`.
        """
        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
        latents_std = self.latents_std.to(latent.device, latent.dtype)
        return latent * latents_std / self.scale_factor + latents_mean
|
|
|
|
|
class SD_X4(LatentFormat):
    """Latent format settings for ``SD_X4``.

    ``taesd_decoder_name`` is not set here, so the base-class default of
    ``None`` applies.
    """

    def __init__(self):
        """Store the fixed scale factor and RGB factor rows."""
        self.scale_factor = 0.08333
        # Four rows of three weights each.
        # NOTE(review): presumably rows are latent channels and columns are
        # R, G, B -- confirm against the preview code that consumes this.
        self.latent_rgb_factors = [
            [-0.2340, -0.3863, -0.3257],
            [0.0994, 0.0885, -0.0908],
            [-0.2833, -0.2349, -0.3741],
            [0.2523, -0.0055, -0.1651],
        ]
|
|
|
class SC_Prior(LatentFormat):
    """Latent format settings for ``SC_Prior``.

    ``taesd_decoder_name`` is not set here, so the base-class default of
    ``None`` applies.
    """

    def __init__(self):
        """Store the scale factor (1.0, i.e. no rescaling) and RGB factor rows."""
        self.scale_factor = 1.0
        # Sixteen rows of three weights each.
        # NOTE(review): presumably rows are latent channels and columns are
        # R, G, B -- confirm against the preview code that consumes this.
        self.latent_rgb_factors = [
            [-0.0326, -0.0204, -0.0127],
            [-0.1592, -0.0427, 0.0216],
            [0.0873, 0.0638, -0.0020],
            [-0.0602, 0.0442, 0.1304],
            [0.0800, -0.0313, -0.1796],
            [-0.0810, -0.0638, -0.1581],
            [0.1791, 0.1180, 0.0967],
            [0.0740, 0.1416, 0.0432],
            [-0.1745, -0.1888, -0.1373],
            [0.2412, 0.1577, 0.0928],
            [0.1908, 0.0998, 0.0682],
            [0.0209, 0.0365, -0.0092],
            [0.0448, -0.0650, -0.1728],
            [-0.1658, -0.1045, -0.1308],
            [0.0542, 0.1545, 0.1325],
            [-0.0352, -0.1672, -0.2541],
        ]
|
|
|
class SC_B(LatentFormat):
    """Latent format settings for ``SC_B``.

    ``taesd_decoder_name`` is not set here, so the base-class default of
    ``None`` applies.
    """

    def __init__(self):
        """Store the scale factor (1.0, i.e. no rescaling) and RGB factor rows."""
        self.scale_factor = 1.0
        # Four rows of three weights each.
        # NOTE(review): presumably rows are latent channels and columns are
        # R, G, B -- confirm against the preview code that consumes this.
        self.latent_rgb_factors = [
            [0.1121, 0.2006, 0.1023],
            [-0.2093, -0.0222, -0.0195],
            [-0.3087, -0.1535, 0.0366],
            [0.0290, -0.1574, -0.4078],
        ]
|
|