# File size: 2,113 Bytes
# Commit: 6fda347
from transformers import AutoModel, AutoConfig
from DaViT.modeling_davit import DaViTModel
from DaViT.configuration_davit import DaViTConfig
from unittest.mock import patch
import os
import logging
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from typing import Tuple, Dict, Any, Union, List
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Collect a module's imports, hiding ``flash_attn`` for Florence-2.

    Drop-in stand-in for ``transformers.dynamic_module_utils.get_imports``,
    used as a workaround for the ``flash_attn`` import error when loading
    the Florence-2 remote modeling code.

    Args:
        filename: Path of the module whose imports are being collected.

    Returns:
        The list produced by ``get_imports(filename)``. When *filename* ends
        with ``modeling_florence2.py``, one ``"flash_attn"`` entry (if present)
        is removed from that list before it is returned.
    """
    found = get_imports(filename)
    targets_florence = str(filename).endswith("modeling_florence2.py")
    # Only the Florence-2 modeling file is filtered; any other module gets
    # its imports back exactly as get_imports reported them.
    if targets_florence and "flash_attn" in found:
        found.remove("flash_attn")
    return found
# Extract the DaViT vision tower from Florence-2-large-ft and publish it,
# together with the Florence-2 processor, as a standalone Hub repository.
SOURCE_MODEL_ID = "microsoft/Florence-2-large-ft"
TARGET_REPO_ID = "DaViT-Florence-2-large-ft"

# Downloads are cached in the current working directory.
current_directory = os.getcwd()

# Register DaViT with the Auto* factories (once) so that
# AutoModel.from_config can build a DaViTModel from a DaViTConfig.
AutoConfig.register("davit", DaViTConfig)
AutoModel.register(DaViTConfig, DaViTModel)

# Mark the custom classes for auto-class export so the pushed repo records
# which Auto* class each one maps to.
DaViTConfig.register_for_auto_class()
DaViTModel.register_for_auto_class("AutoModel")

# Step 1: build the target configuration.
# NOTE(review): assumes the DaViTConfig() defaults match the Florence-2-large
# vision tower; load_state_dict (strict by default) below raises if the
# parameter names or shapes do not line up.
config = DaViTConfig()

# Step 2: load the source checkpoint in fp16 on CPU. get_imports is patched
# so the remote modeling code can be loaded despite the flash_attn import.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        SOURCE_MODEL_ID,
        trust_remote_code=True,
        cache_dir=current_directory,
        device_map="cpu",
        torch_dtype=torch.float16,
    )

# device_map is a model-weight placement option; a processor holds no
# weights, so it is not passed here (it would only be forwarded as an
# unused kwarg to the tokenizer/image processor).
processor = AutoProcessor.from_pretrained(
    SOURCE_MODEL_ID,
    trust_remote_code=True,
    cache_dir=current_directory,
)

# Step 3: build the standalone model and copy the vision tower weights in.
model2 = AutoModel.from_config(config)
model2.to(torch.float16)  # match the fp16 dtype the checkpoint was loaded in
model2.load_state_dict(model.vision_tower.state_dict())

# Step 4: publish the model and the processor to the same Hub repository.
model2.push_to_hub(TARGET_REPO_ID)
processor.push_to_hub(TARGET_REPO_ID)