from typing import Dict, List, Union

import torch
from transformers import Pipeline


class MatterGPTPipeline(Pipeline):
    """Pipeline that samples token sequences conditioned on two properties.

    Each request is a dict carrying ``formation_energy`` and ``band_gap``.
    The pair is passed to ``model.generate`` as the ``prop`` tensor, and
    generation starts from the single context token ``'>'``.

    The tokenizer is expected to expose a ``stoi`` mapping (token -> id) and
    a ``decode`` method; the model is expected to expose ``generate`` and
    ``config.block_size``. These are the only attributes this class uses.
    """

    def __init__(self, model, tokenizer, device=-1):
        """Store ``model``/``tokenizer`` and select the device via ``Pipeline``."""
        super().__init__(model=model, tokenizer=tokenizer, device=device)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts.

        No call-time options are supported, so any ``kwargs`` are ignored.
        """
        return {}, {}, {}

    def preprocess(
        self, inputs: Union[Dict[str, float], List[Dict[str, float]]]
    ) -> Dict[str, torch.Tensor]:
        """Turn one condition dict, or a list of them, into model inputs.

        Args:
            inputs: A dict with ``formation_energy`` and ``band_gap`` keys,
                or a list of such dicts.

        Returns:
            ``{"input_ids": x, "prop": p}`` where ``x`` is a ``(n, 1)`` long
            tensor filled with the id of the ``'>'`` context token and ``p``
            is an ``(n, 1, 2)`` float tensor of the conditions, both on
            ``self.device``.

        Raises:
            KeyError: If an item lacks one of the two property keys, or the
                tokenizer has no ``'>'`` entry in ``stoi``.
        """
        if isinstance(inputs, dict):
            inputs = [inputs]

        # ``item`` rather than ``input`` so the builtin is not shadowed.
        conditions = [
            [item['formation_energy'], item['band_gap']] for item in inputs
        ]

        context = '>'
        start_id = self.tokenizer.stoi[context]

        # Build the tensors directly on the target device instead of creating
        # them on CPU and copying afterwards.
        x = torch.tensor(
            [[start_id]], dtype=torch.long, device=self.device
        ).repeat(len(conditions), 1)
        p = torch.tensor(
            conditions, dtype=torch.float, device=self.device
        ).unsqueeze(1)
        return {"input_ids": x, "prop": p}

    def _forward(self, model_inputs):
        """Sample sequences from the model for the prepared batch.

        Generation uses fixed sampling settings (temperature 1.2, ``top_k=0``,
        ``top_p=0.9``) and is capped at ``model.config.block_size``.
        """
        return self.model.generate(
            model_inputs["input_ids"],
            prop=model_inputs["prop"],
            max_length=self.model.config.block_size,
            temperature=1.2,
            do_sample=True,
            top_k=0,
            top_p=0.9,
        )

    def postprocess(self, model_outputs):
        """Decode every generated sequence with the tokenizer.

        Returns:
            A list with one decoded entry per sequence in ``model_outputs``.
        """
        return [self.tokenizer.decode(seq.tolist()) for seq in model_outputs]

    def __call__(self, inputs: Union[Dict[str, float], List[Dict[str, float]]]):
        """Run preprocess -> generate -> decode for ``inputs``.

        Args:
            inputs: A condition dict, or a list of condition dicts (see
                ``preprocess``).

        Returns:
            A list of decoded sequences, one per condition dict. An empty
            list of inputs yields an empty list without calling the model.
        """
        # Nothing to generate; avoid sending a zero-sized batch to the model.
        if isinstance(inputs, (list, tuple)) and not inputs:
            return []

        pre_processed = self.preprocess(inputs)
        # This override bypasses ``Pipeline.forward``, which normally supplies
        # the no-grad inference context, so disable autograd explicitly.
        with torch.no_grad():
            model_outputs = self._forward(pre_processed)
        return self.postprocess(model_outputs)