# MatterGPT / mattergpt_pipeline.py
from typing import Dict, List, Union

import torch
from transformers import Pipeline


class MatterGPTPipeline(Pipeline):
    """Pipeline that conditions MatterGPT generation on formation energy and band gap targets."""

    def __init__(self, model, tokenizer, device=-1):
        super().__init__(model=model, tokenizer=tokenizer, device=device)

    def _sanitize_parameters(self, **kwargs):
        # No call-time parameters are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(self, inputs: Union[Dict[str, float], List[Dict[str, float]]]) -> Dict[str, torch.Tensor]:
        # Accept either a single condition dict or a list of them.
        if isinstance(inputs, dict):
            inputs = [inputs]
        conditions = [[item['formation_energy'], item['band_gap']] for item in inputs]
        # Every generated sequence starts from the '>' start token.
        context = '>'
        x = torch.tensor([self.tokenizer.stoi[context]], dtype=torch.long)[None, ...].repeat(len(conditions), 1).to(self.device)
        # Shape (batch, 1, 2): one (formation_energy, band_gap) pair per sequence.
        p = torch.tensor(conditions, dtype=torch.float).unsqueeze(1).to(self.device)
        return {"input_ids": x, "prop": p}
    def _forward(self, model_inputs):
        # Sample token sequences conditioned on the target properties.
        return self.model.generate(
            model_inputs["input_ids"],
            prop=model_inputs["prop"],
            max_length=self.model.config.block_size,
            temperature=1.2,
            do_sample=True,
            top_k=0,
            top_p=0.9,
        )
    def postprocess(self, model_outputs):
        # Decode each generated token sequence back into a string.
        return [self.tokenizer.decode(seq.tolist()) for seq in model_outputs]

    def __call__(self, inputs: Union[Dict[str, float], List[Dict[str, float]]]):
        pre_processed = self.preprocess(inputs)
        model_outputs = self._forward(pre_processed)
        return self.postprocess(model_outputs)
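

# Minimal usage sketch (not part of the original file). Loading the trained model and
# tokenizer is repo-specific and omitted here; assume `model` is the trained MatterGPT
# model and `tokenizer` is its matching tokenizer (one exposing the `stoi` vocabulary
# and `decode` method used above). The property values below are illustrative only.
#
#     pipe = MatterGPTPipeline(model=model, tokenizer=tokenizer, device=-1)
#     # One target per dict: desired formation energy and band gap.
#     generated = pipe([{"formation_energy": -1.0, "band_gap": 2.0},
#                       {"formation_energy": -2.0, "band_gap": 0.0}])
#     print(generated)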