Spaces:
Running
Running
File size: 2,004 Bytes
f3b2c5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from enum import Enum
import torch
from model_classes import Model200M, Model5M, SyntheticV2
from model_transforms import transform_200M, transform_5M, transform_synthetic
class ModelType(str, Enum):
    """Names of the models this module can load.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (e.g. ``ModelType.MIDJOURNEY_5M == "midjourney_5M"``).
    """

    MIDJOURNEY_200M = "midjourney_200M"
    DIFFUSIONS_200M = "diffusions_200M"
    MIDJOURNEY_5M = "midjourney_5M"
    DIFFUSIONS_5M = "diffusions_5M"
    SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"

    def __str__(self):
        """Return the raw string value rather than ``ModelType.NAME``."""
        # The value is already a plain ``str`` for every member.
        return self.value

    @staticmethod
    def get_list():
        """Return the string values of all members, in definition order."""
        return [member.value for member in ModelType]
def load_model(value: ModelType):
    """Load the checkpoint for ``value`` and return the model in eval mode.

    The model instance is taken from ``type_to_class`` and its weights are
    read from the file named in ``type_to_path``; tensors are mapped to the
    CPU while loading. The instance stored in ``type_to_class`` is updated
    in place and is the object that gets returned.

    Args:
        value: The model type to load.

    Returns:
        The model for ``value`` with the checkpoint's state dict applied and
        switched to evaluation mode.

    Raises:
        KeyError: If ``value`` has no entry in the lookup tables.
    """
    network = type_to_class[value]
    checkpoint_path = type_to_path[value]

    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the torch version's ``weights_only`` default -- this assumes the
    # checkpoint files are trusted; confirm.
    cpu = torch.device('cpu')
    state_dict = torch.load(checkpoint_path, map_location=cpu)

    network.load_state_dict(state_dict)
    network.eval()
    return network
# Model instance for each model type. Every entry is built by its own
# constructor call, so types that share an architecture still get separate
# instances.
type_to_class = {
    model_type: architecture()
    for model_type, architecture in (
        (ModelType.MIDJOURNEY_200M, Model200M),
        (ModelType.DIFFUSIONS_200M, Model200M),
        (ModelType.MIDJOURNEY_5M, Model5M),
        (ModelType.DIFFUSIONS_5M, Model5M),
        (ModelType.SYNTHETIC_DETECTOR_V2, SyntheticV2),
    )
}
# Checkpoint file for each model type (relative paths under ``models/``).
type_to_path = dict(
    (
        (ModelType.MIDJOURNEY_200M, 'models/midjourney200M.pt'),
        (ModelType.DIFFUSIONS_200M, 'models/diffusions200M.pt'),
        (ModelType.MIDJOURNEY_5M, 'models/midjourney5M.pt'),
        (ModelType.DIFFUSIONS_5M, 'models/diffusions5M.pt'),
        (ModelType.SYNTHETIC_DETECTOR_V2, 'models/synthetic_detector_v2.pt'),
    )
)
# Ready-to-use model for each model type. All checkpoints are loaded eagerly
# when this module is imported, in the enum's definition order.
type_to_loaded_model = {
    model_type: load_model(model_type) for model_type in ModelType
}
# Transform to use with each model type, grouped by the transform they share.
type_to_transforms = {
    **dict.fromkeys(
        (ModelType.MIDJOURNEY_200M, ModelType.DIFFUSIONS_200M),
        transform_200M,
    ),
    **dict.fromkeys(
        (ModelType.MIDJOURNEY_5M, ModelType.DIFFUSIONS_5M),
        transform_5M,
    ),
    ModelType.SYNTHETIC_DETECTOR_V2: transform_synthetic,
}