File size: 795 Bytes
0dce0bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import math
import copy
import torch
from torch.nn import functional as F
import torch.nn as nn
from .model_proteinglm_clm import ProteinGLMForGeneration
class MSAGPT(ProteinGLMForGeneration):
    """MSAGPT model, a thin subclass of ``ProteinGLMForGeneration``.

    Beyond registering its own (currently empty) command-line argument
    group, it adds no behavior over the parent class.
    """

    def __init__(self, args, transformer=None, **kwargs):
        """Forward all construction arguments unchanged to the parent.

        Args:
            args: Configuration object passed straight through to
                ``ProteinGLMForGeneration.__init__``.
            transformer: Forwarded as the ``transformer`` keyword; its
                meaning is defined by the parent class.
            **kwargs: Any extra keyword arguments, forwarded as-is.
        """
        super().__init__(args, transformer=transformer, **kwargs)

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register the MSAGPT argument group, then defer to the parent.

        Args:
            parser: Object exposing ``add_argument_group`` -- presumably an
                ``argparse.ArgumentParser``; confirm against callers.

        Returns:
            Whatever the parent ``add_model_specific_args`` returns
            (presumably the same parser -- TODO confirm).
        """
        # No arguments are added to this group yet, so the returned group is
        # not kept; the call itself still registers the titled group on the
        # parser, which is why it stays.
        parser.add_argument_group('MSAGPT-inference', 'MSAGPT inference Configurations')
        return super().add_model_specific_args(parser)
class FineTuneMSAGPT(MSAGPT):
    """Subclass of ``MSAGPT`` named as its fine-tuning variant.

    It currently adds nothing: the constructor only forwards its arguments
    to ``MSAGPT``.
    """

    def __init__(self, args, transformer=None, **kwargs):
        """Forward all construction arguments unchanged to ``MSAGPT``.

        Args:
            args: Configuration object passed straight through to the parent.
            transformer: Forwarded as the ``transformer`` keyword; its
                meaning is defined by the parent class chain.
            **kwargs: Any extra keyword arguments, forwarded as-is.
        """
        super().__init__(args, transformer=transformer, **kwargs)