MatterGPT / mattergpt_wrapper.py
xiaohang07's picture
Upload 7 files
4475574 verified
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from model import GPT, GPTConfig # Import your original model and config classes
import json
class CustomGPTConfig(PretrainedConfig):
    """Hugging Face config object that stores arbitrary keyword settings.

    Every keyword argument is first handed to ``PretrainedConfig.__init__``
    and then assigned on the instance under its own name, so the wrapped
    model can read back whatever hyper-parameters were supplied.
    """

    model_type = "gpt"

    def __init__(self, **kwargs):
        """Initialise the base config, then mirror all kwargs as attributes.

        Args:
            **kwargs: Arbitrary configuration entries. Each one is forwarded
                to the parent initializer and also set via ``setattr``.
        """
        super().__init__(**kwargs)
        # Re-apply each entry after the base initializer has run so the
        # value passed by the caller is what ends up on the instance.
        for name in kwargs:
            setattr(self, name, kwargs[name])
class MatterGPTWrapper(PreTrainedModel):
    """``PreTrainedModel`` adapter around the project's ``GPT`` network.

    The wrapped network lives in ``self.model``. Loading and saving are
    overridden to use a ``config.json`` plus a ``pytorch_model.pt`` state
    dict stored side by side in one directory.
    """

    config_class = CustomGPTConfig
    base_model_prefix = "gpt"

    def __init__(self, config):
        """Build the inner ``GPT`` from the attributes of ``config``.

        Args:
            config: Configuration object; its whole ``__dict__`` is expanded
                as keyword arguments to ``GPTConfig``.
        """
        super().__init__(config)
        gpt_config = GPTConfig(**config.__dict__)
        self.model = GPT(gpt_config)

    def forward(self, input_ids, attention_mask=None, labels=None, prop=None):
        """Run the inner model.

        Args:
            input_ids: Token ids passed positionally to the inner model.
            attention_mask: Accepted for API compatibility; not forwarded.
            labels: Forwarded to the inner model as ``targets``.
            prop: Forwarded to the inner model as ``prop``.

        Returns:
            Whatever the inner ``GPT`` call returns.
        """
        return self.model(input_ids, targets=labels, prop=prop)

    def generate(self, input_ids, prop, max_length, num_return_sequences=1, **kwargs):
        """Sample from the inner model up to a total length of ``max_length``.

        Args:
            input_ids: Prompt ids; ``input_ids.shape[1]`` is taken as the
                current sequence length.
            prop: Forwarded to ``GPT.sample`` as ``prop``.
            max_length: Target total length; the number of sampling steps is
                ``max_length - input_ids.shape[1]``.
            num_return_sequences: Accepted but not used by this method.
            **kwargs: Passed through unchanged to ``GPT.sample``.

        Returns:
            Whatever ``GPT.sample`` returns.
        """
        prompt_length = input_ids.shape[1]
        steps = max_length - prompt_length
        return self.model.sample(input_ids, steps, prop=prop, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
        """Create a wrapper from a local directory.

        Reads ``config.json`` and ``pytorch_model.pt`` from
        ``pretrained_model_path``. ``*model_args`` and ``**kwargs`` are
        accepted but ignored.

        Args:
            pretrained_model_path: Directory holding the two files.

        Returns:
            A new instance whose inner model has the saved weights loaded.
        """
        config_path = f"{pretrained_model_path}/config.json"
        weights_path = f"{pretrained_model_path}/pytorch_model.pt"

        with open(config_path, 'r') as handle:
            settings = json.load(handle)

        wrapper = cls(CustomGPTConfig(**settings))

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        weights = torch.load(weights_path, map_location="cpu")
        wrapper.model.load_state_dict(weights)
        return wrapper

    def save_pretrained(self, save_directory):
        """Write the config and the inner model's state dict to a directory.

        Args:
            save_directory: Target directory; the config is saved through
                ``self.config.save_pretrained`` and the weights go to
                ``pytorch_model.pt`` inside it.
        """
        self.config.save_pretrained(save_directory)
        weights_path = f"{save_directory}/pytorch_model.pt"
        torch.save(self.model.state_dict(), weights_path)
class SimpleTokenizer:
    """Whitespace tokenizer backed by a one-token-per-line vocabulary file.

    The vocabulary is the sorted set of the file's lines plus the two
    special tokens ``'<'`` and ``'>'``; a token's id is its position in
    that sorted list.
    """

    def __init__(self, vocab_file):
        """Load the vocabulary and build both lookup tables.

        Args:
            vocab_file: Path to a text file with one token per line.
        """
        with open(vocab_file, 'r', encoding="utf-8") as f:
            lines = f.read().splitlines()
        # Sorting the de-duplicated tokens keeps ids stable regardless of
        # the order of lines in the file.
        self.vocab = sorted(set(lines + ['<', '>']))
        self.stoi = {ch: i for i, ch in enumerate(self.vocab)}
        self.itos = {i: ch for i, ch in enumerate(self.vocab)}

    def encode(self, text):
        """Convert a whitespace-separated string into a list of token ids.

        Raises:
            KeyError: If any token is not in the vocabulary.
        """
        return [self.stoi[token] for token in text.split()]

    def decode(self, ids):
        """Convert an iterable of ids back into a space-joined string.

        Each element is converted with ``int()`` *before* the vocabulary
        lookup. Testing the raw element for membership fails for objects
        hashed by identity (for example the 0-d tensors produced by
        iterating a ``torch.Tensor``), which made every id look unknown
        and yielded an empty string.

        Elements that cannot be converted to ``int`` and ids missing from
        the vocabulary are skipped. All ``'<'`` characters are removed from
        the joined text and surrounding whitespace is stripped.
        """
        tokens = []
        for raw in ids:
            try:
                index = int(raw)
            except (TypeError, ValueError):
                # Keep the best-effort contract: unusable entries are dropped.
                continue
            token = self.itos.get(index)
            if token is not None:
                tokens.append(token)
        return " ".join(tokens).replace("<", "").strip()

    def __call__(self, text, return_tensors=None):
        """Encode ``text`` and wrap the ids in a batch of size one.

        Args:
            text: Whitespace-separated tokens.
            return_tensors: When ``'pt'``, ``input_ids`` is a ``torch``
                tensor; otherwise it is a nested Python list.

        Returns:
            A dict with the single key ``'input_ids'``.
        """
        encoded = self.encode(text)
        if return_tensors == 'pt':
            import torch
            return {'input_ids': torch.tensor([encoded])}
        return {'input_ids': [encoded]}