PRESTO / README.md
hcaoaf's picture
Upload 10 files
2b72642 verified
|
raw
history blame
15.5 kB
metadata
license: apache-2.0
dataset: sft
tags:
  - finetuned
  - multimodal
inference: false

These are weights for a version of `checkpoints/stage2/llava-moleculestm-vicuna-7b-v1.5-pretrain_rxn_nc` finetuned for multimodal applications.

Modalities

- Molecule2DModality (use `<molecule_2d>` in text and provide molecules)

Usage

GitHub: https://github.com/sshh12/bioagent (includes training scripts and a basic inference server)

Dataset

sft (765299 examples)

{'molecule_2d': [(tensor([[34,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 6,  0],
        [ 5,  0]]), tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 1],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 1, 6]]), tensor([[0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0]])), (tensor([[5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [5, 0],
        [5, 0]]), tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 4, 6, 6, 7, 7, 1],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 4, 7, 6, 1, 7]]), tensor([[0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0]])), (tensor([[ 7,  0],
        [15,  0],
        [ 7,  0],
        [ 7,  0],
        [45,  0],
        [ 6,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 8,  0],
        [ 8,  0],
        [ 8,  0]]), tensor([[ 0,  1,  1,  2,  1,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
          9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17,  1, 18,
         18, 19, 18, 20, 18, 21, 17,  4, 11,  6, 17, 12],
        [ 1,  0,  2,  1,  3,  1,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7,  9,  8,
         10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18,  1,
         19, 18, 20, 18, 21, 18,  4, 17,  6, 11, 12, 17]]), tensor([[1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0]])), (tensor([[ 5,  0],
        [ 7,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 7,  0],
        [ 5,  0],
        [ 5,  0],
        [14,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0],
        [ 5,  0]]), tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  5,  8,  8,  9,
          9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 16, 18,
         14, 19,  9, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27,
         26, 28, 24, 29,  8, 30, 30, 31, 31, 32, 32, 33, 33, 34, 33, 35, 32, 36,
         36, 37, 37, 38, 38, 39, 38, 40, 37, 41, 41, 42, 42, 43, 43, 44, 43, 45,
         30,  2, 42, 31, 18, 10, 28, 20, 19, 10, 29, 20, 17, 12, 27, 22],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  5,  9,  8,
         10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18, 16,
         19, 14, 20,  9, 21, 20, 22, 21, 23, 22, 24, 23, 25, 24, 26, 25, 27, 26,
         28, 26, 29, 24, 30,  8, 31, 30, 32, 31, 33, 32, 34, 33, 35, 33, 36, 32,
         37, 36, 38, 37, 39, 38, 40, 38, 41, 37, 42, 41, 43, 42, 44, 43, 45, 43,
          2, 30, 31, 42, 10, 18, 20, 28, 10, 19, 20, 29, 12, 17, 22, 27]]), tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]])), (tensor([[ 5,  0],
        [ 5,  0],
        [ 6,  0],
        [14,  0],
        [ 6,  0],
        [14,  0],
        [ 6,  0],
        [ 5,  0],
        [ 5,  0],
        [ 6,  0],
        [ 5,  0],
        [ 5,  0],
        [ 6,  0],
        [ 5,  0],
        [ 5,  0],
        [ 6,  0],
        [ 5,  0],
        [ 5,  0],
        [ 6,  0],
        [ 5,  0],
        [ 5,  0]]), tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  6,  8,  5,  9,
          9, 10,  9, 11,  5, 12, 12, 13, 12, 14,  3, 15, 15, 16, 15, 17,  3, 18,
         18, 19, 18, 20],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  6,  9,  5,
         10,  9, 11,  9, 12,  5, 13, 12, 14, 12, 15,  3, 16, 15, 17, 15, 18,  3,
         19, 18, 20, 18]]), tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 0],
        [1, 0],
        [0, 0],
        [0, 0],
        [1, 0],
        [1, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]])), (tensor([[5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [7, 0]]), tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 1],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 1, 5]]), tensor([[0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0]])), (tensor([[5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [5, 0],
        [5, 0],
        [5, 0]]), tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
          9, 10, 10, 11,  4, 12, 12, 13, 13,  1, 11,  6],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7,  9,  8,
         10,  9, 11, 10, 12,  4, 13, 12,  1, 13,  6, 11]]), tensor([[0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0]]))], 'input_ids': tensor([    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
          526,   263,  8950,   391, 29889,  2567,   366,   526,  2183,   263,
        19848,  6306, 29889,  3529,  8500,   278,  1950,   337,   351,  1237,
          310,   278, 19848, 29889,   450, 19848,  6306,   756,   278,  1494,
         3402, 29901,    13, 28956,    13,  8423,   424, 29896, 29889,  8423,
          424, 29906, 29889,  2023,   869,  8423,   424, 29940,  6778,  4704,
           13, 28956,    13,  1576,   736,   995,   881,   367,   297,  3464,
          310, 29871, 29900, 29899, 29896, 29889,   450,  6133,   278,   995,
        29892,   278,   901,  5517,   278, 19848,   338,   304,  6403, 29889,
        29871,    13,  4806,  3867,   278,  3829,   310,   278,   337,  7387,
        29889,    13, 29966,   829, 14816, 29903,  6778,    13,    13, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996,   869, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996,   869, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,   869, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996,   869, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996,   869, -3996, -3996,
        -3996, -3996, -3996, -3996,  5099, -3996, -3996, -3996, -3996, -3996,
        -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, -3996, 29871,
         5293,   278, 22233, 19848,  2472, 29892,   825,   338,   278, 11959,
          310,   278, 19848, 29915, 29879,  7709, 29973,   518, 29914, 25580,
        29962,   259, 29900, 29889, 29900, 29941, 29953, 29947, 29871,     2]), 'labels': tensor([    1,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,   259, 29900, 29889, 29900, 29941, 29953, 29947, 29871,     2])}

Training Device(s)

```
name, pci.bus_id, vbios_version
A100-SXM4-40GB, 00000000:07:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:0F:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:47:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:4E:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:87:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:90:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:B7:00.0, 92.00.45.00.03
A100-SXM4-40GB, 00000000:BD:00.0, 92.00.45.00.03
```

Model

LlamaLMMForCausalLM.model =

```
LlamaLMMForCausalLM(
  (model): LlamaLMMModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (molecule_2d_lmm_projector): _MLPVectorProjector(
      (mlp): Sequential(
        (0): Linear(in_features=300, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=4096, bias=True)
      )
    )
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
```