These are weights for a version of mistralai/Mistral-7B-Instruct-v0.1 fine-tuned for multimodal applications.

Modalities

  • VoxelModality (use <voxel> in the text and provide voxel_data, a list of 217 floats; each voxel input is encoded as 2 tokens)
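A minimal sketch of what "encoded as 2 tokens" implies for the input sequence. It assumes each <voxel> placeholder is expanded into two reserved positions, which are then overwritten with the two vectors produced for that voxel input (4096-dimensional in this model; 1-dimensional in the toy usage). `expand_voxel_placeholders`, `splice_voxel_embeddings` and the `VOXEL_SLOT` id are hypothetical illustrations, not multi_token's actual implementation.

```python
from typing import List, Sequence, Tuple, Union

VOXEL_TOKENS = 2  # from this card: each voxel input is encoded as 2 tokens
VOXEL_SLOT = -1   # hypothetical id marking a position reserved for a voxel embedding

Segment = Union[str, List[int]]


def expand_voxel_placeholders(segments: List[Segment]) -> Tuple[List[int], List[int]]:
    """Flatten token-id segments and "<voxel>" placeholders into one sequence.

    Each "<voxel>" becomes VOXEL_TOKENS reserved slots. Also returns the start
    index of every reserved span so embeddings can be written there later.
    """
    ids: List[int] = []
    starts: List[int] = []
    for seg in segments:
        if seg == "<voxel>":
            starts.append(len(ids))
            ids.extend([VOXEL_SLOT] * VOXEL_TOKENS)
        elif isinstance(seg, str):
            raise ValueError(f"tokenize text before expanding: {seg!r}")
        else:
            ids.extend(seg)
    return ids, starts


def splice_voxel_embeddings(
    embeds: Sequence[Sequence[float]],
    starts: Sequence[int],
    voxel_vectors: Sequence[Sequence[Sequence[float]]],
) -> List[List[float]]:
    """Overwrite each reserved span with the VOXEL_TOKENS vectors for that input."""
    if len(starts) != len(voxel_vectors):
        raise ValueError("need one group of vectors per <voxel> placeholder")
    out = [list(e) for e in embeds]  # copy so the caller's embeddings are untouched
    for start, vectors in zip(starts, voxel_vectors):
        if len(vectors) != VOXEL_TOKENS:
            raise ValueError(f"expected {VOXEL_TOKENS} vectors per voxel input")
        for offset, vec in enumerate(vectors):
            out[start + offset] = list(vec)
    return out


# Toy usage: made-up token ids and 1-dimensional "embeddings".
ids, starts = expand_voxel_placeholders([[10, 11, 12], "<voxel>", [13, 14]])
print(ids)     # [10, 11, 12, -1, -1, 13, 14]
print(starts)  # [3]
embeds = splice_voxel_embeddings([[0.0]] * len(ids), starts, [[[9.0], [8.0]]])
print(embeds[3], embeds[4])  # [9.0] [8.0]
```

Reserving the positions before the embedding step keeps the sequence length, and therefore attention masks and position ids, fixed from the start; whether multi_token orders these steps the same way is not stated in this card.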

Usage

GitHub: https://github.com/sshh12/multi_token (includes training scripts and a basic inference server)

Dataset

/workspace/multi_token/new/multi_token/data/sentence-voxel-pretrain (202 examples)

Example record:

{'voxel_data': [-1.2669486999511719, -4.342422008514404, 0.08342710882425308, -1.1121463775634766, -1.7241164445877075, 0.8711026906967163, 1.6187070608139038, 2.1467154026031494, 1.55600106716156, 2.7908051013946533, 2.6149775981903076, 0.48798438906669617, -1.8658868074417114, -0.9153737425804138, 1.0539007186889648, 2.9938547611236572, -1.4584662914276123, 0.06789205223321915, 0.7774376273155212, 0.21760278940200806, -1.8041378259658813, 2.964979648590088, -1.1315451860427856, 0.17553456127643585, -0.30490806698799133, -0.2574838697910309, 0.46714287996292114, -1.0232142210006714, -0.8084980845451355, -1.2524477243423462, -3.438807487487793, 1.2044878005981445, -1.3203097581863403, -1.5149697065353394, 1.3110711574554443, -0.6502295136451721, 0.2924231290817261, -1.8042508363723755, 1.156070351600647, 3.68827748298645, -1.2678762674331665, -0.48739099502563477, -1.9123613834381104, -0.5652288794517517, 0.30757156014442444, -2.6405975818634033, -0.5657948851585388, 0.1962834596633911, 0.4952268898487091, -1.7487742900848389, 1.7829053401947021, -1.7034624814987183, -0.5107262134552002, -0.3320123553276062, -0.06942156702280045, 0.4950488209724426, 2.344041109085083, -1.5664364099502563, 0.19259212911128998, -3.1398189067840576, 0.04002213105559349, -1.2993210554122925, -1.6680536270141602, -1.251158595085144, 1.8072421550750732, -1.0329501628875732, 0.9539159536361694, 1.3106855154037476, -2.569223165512085, -1.2958600521087646, 0.126902237534523, 0.5233652591705322, 0.5843154788017273, -0.5259942412376404, -0.6380230784416199, -0.6816728115081787, -1.121833324432373, 0.3703728914260864, 1.237956166267395, 0.5594802498817444, -0.5233862996101379, -0.13332879543304443, 0.675186276435852, -1.2282785177230835, -3.3140101432800293, 0.7235065698623657, -0.35910749435424805, -2.077662467956543, 0.25364214181900024, -0.04129992425441742, -1.2904301881790161, -1.616705060005188, -1.6876271963119507, -0.7963595390319824, 0.030134305357933044, 1.8337446451187134, 
-0.7175531983375549, -1.975988745689392, 2.4509336948394775, 0.7048704028129578, 1.4666917324066162, 1.7357171773910522, -2.5205185413360596, 0.3177747130393982, 3.1697638034820557, -0.9803237915039062, 0.2490101158618927, 0.685883104801178, -0.5148935317993164, -0.6637391448020935, 1.1980229616165161, -2.6742348670959473, -0.3336712718009949, 0.7613745927810669, 0.4145558178424835, -0.39548221230506897, -0.8612095713615417, 0.47160154581069946, 1.5164895057678223, -0.7074841260910034, -1.4712883234024048, 0.9962572455406189, -1.2678629159927368, -0.37773820757865906, -1.8931519985198975, -0.05409574508666992, 2.9137215614318848, -0.8817853331565857, 0.6903612613677979, 0.4531203806400299, -1.6106483936309814, 0.23891609907150269, -0.7575222253799438, -0.8597385883331299, -0.4505012631416321, -1.0164486169815063, -2.209623336791992, -0.4585776627063751, -0.8505887389183044, 2.003972291946411, -1.3250545263290405, 3.2319674491882324, 2.2695298194885254, -0.8775315880775452, -0.628717303276062, -0.43926355242729187, 1.9588313102722168, -0.93973308801651, 0.12314625084400177, -0.33370646834373474, 0.07034939527511597, -2.8057355880737305, 1.337593674659729, -0.555436372756958, -2.6099681854248047, -0.712677538394928, 1.286773920059204, 0.38860979676246643, 0.8785397410392761, -1.712486743927002, -0.24093347787857056, 0.1924627721309662, -0.0006318278610706329, -1.6611075401306152, 0.2844694256782532, -1.7149747610092163, -0.5365468859672546, 0.13996855914592743, -0.056381598114967346, 1.8396815061569214, 0.8105614185333252, -1.2487802505493164, 0.4743833541870117, 0.1982801854610443, -0.15110887587070465, 1.4873329401016235, 0.5023205280303955, 0.1126936599612236, 1.627712607383728, -1.4724937677383423, 1.760959267616272, 0.17591479420661926, -1.152338981628418, -0.9325122833251953, 1.3554235696792603, 0.8807990550994873, 0.19217203557491302, -0.3776297867298126, 0.6159052848815918, -0.8186436891555786, 0.2990851104259491, 0.09922473132610321, 0.2839311957359314, 
0.3771292567253113, -0.12268450111150742, -1.2299126386642456, 0.5846585631370544, -0.3947390019893646, 1.7231228351593018, 0.33239540457725525, -1.3260372877120972, 0.4368828535079956, 0.2650435268878937, 0.5281450152397156, -1.058358073234558, 0.6126224994659424, -0.688051700592041, 0.8823887705802917, -0.9234603047370911, -0.18388473987579346, -1.1497560739517212, -0.10189923644065857, -1.4299086332321167, 0.4046390950679779, -0.3188319206237793, 1.111311912536621, -1.0168960094451904], 'messages': [{'content': 'What scene might have been seen to cause these voxel activations? <voxel> ', 'role': 'user'}, {'content': 'Plate of spaghetti with basil, peppers, tomatoes, and bananas background.', 'role': 'assistant'}]}
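The record above pairs a flat list of floats (`voxel_data`) with a chat-style `messages` list whose user turn contains one <voxel> placeholder. The vector has 217 entries, which matches the `in_features` of the first layer of `voxel_lmm_projector` shown under Model. A sanity check for records of this shape might look like the sketch below; it assumes exactly one flat `voxel_data` list per record matched by exactly one <voxel> placeholder. `check_voxel_record` is a hypothetical helper, not part of multi_token.

```python
import math
from typing import Any, Dict

VOXEL_DIM = 217  # in_features of the first Linear in voxel_lmm_projector
VALID_ROLES = {"user", "assistant"}


def check_voxel_record(record: Dict[str, Any]) -> None:
    """Raise ValueError if a record does not look like the example above.

    Hypothetical helper; assumes one flat `voxel_data` list per record,
    matched by exactly one "<voxel>" placeholder across all messages.
    """
    voxels = record.get("voxel_data")
    if not isinstance(voxels, list) or len(voxels) != VOXEL_DIM:
        raise ValueError(f"voxel_data must be a list of {VOXEL_DIM} floats")
    if not all(isinstance(v, (int, float)) and math.isfinite(v) for v in voxels):
        raise ValueError("voxel_data contains a non-numeric or non-finite value")

    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError("messages must be a non-empty list")
    for msg in messages:
        if msg.get("role") not in VALID_ROLES or not isinstance(msg.get("content"), str):
            raise ValueError(f"malformed message: {msg!r}")

    placeholders = sum(m["content"].count("<voxel>") for m in messages)
    if placeholders != 1:
        raise ValueError(f"expected 1 <voxel> placeholder, found {placeholders}")


# Usage with a synthetic record (zeros instead of real voxel activations).
check_voxel_record({
    "voxel_data": [0.0] * VOXEL_DIM,
    "messages": [
        {"role": "user", "content": "What scene might have been seen to cause these voxel activations? <voxel> "},
        {"role": "assistant", "content": "Plate of spaghetti with basil, peppers, tomatoes, and bananas background."},
    ],
})
```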

Training Device(s)

name, pci.bus_id, vbios_version
NVIDIA H100 80GB HBM3, 00000000:66:00.0, 96.00.99.00.01

Model

MistralLMMForCausalLM.model =

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralLMMForCausalLM(
      (model): MistralLMMModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (up_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear(
                (base_layer): Linear(in_features=14336, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
            (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
          )
        )
        (norm): MistralRMSNorm((4096,), eps=1e-05)
        (voxel_lmm_projector): _MLPVectorProjector(
          (mlps): ModuleList(
            (0-1): 2 x Sequential(
              (0): Linear(in_features=217, out_features=4096, bias=True)
              (1): GELU(approximate='none')
              (2): Linear(in_features=4096, out_features=4096, bias=True)
              (3): GELU(approximate='none')
              (4): Linear(in_features=4096, out_features=4096, bias=True)
              (5): GELU(approximate='none')
              (6): Linear(in_features=4096, out_features=4096, bias=True)
            )
          )
        )
      )
      (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
    )
  )
)
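The shapes printed above are enough to work out how many parameters the rank-64 LoRA factors and the voxel projector add on top of the base model. The sketch below does that arithmetic, assuming the empty `lora_embedding_*` and `lora_magnitude_vector` containers hold no parameters; it counts parameters from the printout only and does not claim which modules were trained or saved.

```python
# Shapes copied from the module printout above.
HIDDEN, KV_OUT, MLP = 4096, 1024, 14336
LORA_RANK = 64
NUM_LAYERS = 32
VOXEL_DIM = 217
NUM_PROJECTOR_MLPS = 2  # "(0-1): 2 x Sequential" in voxel_lmm_projector

# (in_features, out_features) of every LoRA-wrapped Linear in one decoder layer.
LORA_TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_params(in_f: int, out_f: int, r: int = LORA_RANK) -> int:
    # lora_A maps in_f -> r and lora_B maps r -> out_f; both have bias=False.
    return in_f * r + r * out_f


def linear_params(in_f: int, out_f: int, bias: bool = True) -> int:
    return in_f * out_f + (out_f if bias else 0)


per_layer = sum(lora_params(i, o) for i, o in LORA_TARGETS.values())
lora_total = per_layer * NUM_LAYERS

# Each projector MLP: 217 -> 4096, then three 4096 -> 4096 layers, all with bias.
per_mlp = linear_params(VOXEL_DIM, HIDDEN) + 3 * linear_params(HIDDEN, HIDDEN)
projector_total = per_mlp * NUM_PROJECTOR_MLPS

print(f"LoRA per layer:  {per_layer:,}")        # 5,242,880
print(f"LoRA, 32 layers: {lora_total:,}")       # 167,772,160
print(f"Projector total: {projector_total:,}")  # 102,473,728
```

By this count the LoRA factors add about 168M parameters across the 32 decoder layers and the projector adds about 102M more, most of it in the three 4096 × 4096 layers of each MLP.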

Framework versions

  • PEFT 0.12.0

Model tree

brainvivo/measured_voxel_decoder_2_tokens is a PEFT (LoRA) adapter for mistralai/Mistral-7B-Instruct-v0.1.