Feature Extraction
Transformers
Safetensors
custom_code
gheinrich commited on
Commit
7aea24b
·
verified ·
1 Parent(s): 1f50832

Upload model

Browse files
Files changed (2) hide show
  1. hf_model.py +3 -0
  2. open_clip_adaptor.py +41 -0
hf_model.py CHANGED
@@ -23,6 +23,8 @@ from .common import RESOURCE_MAP, DEFAULT_VERSION
23
 
24
  # Import all required modules.
25
  from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
 
 
26
  from .adaptor_registry import adaptor_registry
27
  from .cls_token import ClsToken
28
  from .enable_cpe_support import enable_cpe
@@ -31,6 +33,7 @@ from .eradio_model import eradio
31
  from .radio_model import create_model_from_args
32
  from .radio_model import RADIOModel as RADIOModelBase, Resolution
33
  from .input_conditioner import get_default_conditioner, InputConditioner
 
34
  from .vit_patch_generator import ViTPatchGenerator
35
  from .vitdet import apply_vitdet_arch, VitDetArgs
36
 
 
23
 
24
  # Import all required modules.
25
  from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
26
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
27
+ from .adaptor_mlp import create_mlp_from_state
28
  from .adaptor_registry import adaptor_registry
29
  from .cls_token import ClsToken
30
  from .enable_cpe_support import enable_cpe
 
33
  from .radio_model import create_model_from_args
34
  from .radio_model import RADIOModel as RADIOModelBase, Resolution
35
  from .input_conditioner import get_default_conditioner, InputConditioner
36
+ from .open_clip_adaptor import OpenCLIP_RADIO
37
  from .vit_patch_generator import ViTPatchGenerator
38
  from .vitdet import apply_vitdet_arch, VitDetArgs
39
 
open_clip_adaptor.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from .adaptor_registry import adaptor_registry, dict_t, state_t
15
+
16
+ from .adaptor_generic import GenericAdaptor
17
+
18
+
19
+ class OpenCLIP_RADIO(GenericAdaptor):
20
+ def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
21
+ super().__init__(main_config, adaptor_config, state)
22
+
23
+ import open_clip
24
+
25
+ self.oc_model = open_clip.create_model_from_pretrained(
26
+ model_name=adaptor_config['model'],
27
+ pretrained=adaptor_config['pretrained'],
28
+ return_transform=False,
29
+ )
30
+ # Unload these parameters
31
+ self.oc_model.visual = None
32
+
33
+ self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])
34
+
35
+ def encode_text(self, text, normalize: bool = False):
36
+ return self.oc_model.encode_text(text, normalize=normalize)
37
+
38
+
39
+ @adaptor_registry.register_adaptor("open_clip")
40
+ def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
41
+ return OpenCLIP_RADIO(main_config, adaptor_config, state)