abideen committed
Commit 389a649 · verified · 1 Parent(s): 8f0a9cf

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,4 @@
+ wandb
+ src/__pycache__
+ scripts/run_orpo.sh
+ src/accelerate/fsdp.yaml
README.md ADDED
@@ -0,0 +1,53 @@
+ # **ORPO**
+
+ This is the official repository for [**Reference-free Monolithic Preference Optimization with Odds Ratio**](https://arxiv.org/abs/2403.07691). The detailed results from the paper can be found in:
+ - [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=kaist-ai%2Fmistral-orpo-beta)
+ - [AlpacaEval](#alpacaeval)
+ - [MT-Bench](#mt-bench)
+ - [IFEval](#ifeval)
+
+ &nbsp;
+
+ ### **`Model Checkpoints`**
+
+ Our models trained with ORPO can be found at:
+
+ - [X] **Mistral-ORPO-⍺**: 🤗 [kaist-ai/mistral-orpo-alpha](https://huggingface.co/kaist-ai/mistral-orpo-alpha)
+ - [X] **Mistral-ORPO-β**: 🤗 [kaist-ai/mistral-orpo-beta](https://huggingface.co/kaist-ai/mistral-orpo-beta)
+
+ The corresponding training logs, tracking the average log probabilities of the chosen/rejected responses, are reported in:
+
+ - [X] **Mistral-ORPO-⍺**: [Wandb Report for Mistral-ORPO-⍺](https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE1NzE0?accessToken=rms6o4mg5vo3feu1bvbpk632m4cspe19l0u1p4he3othx5bgean82chn9neiile6)
+ - [X] **Mistral-ORPO-β**: [Wandb Report for Mistral-ORPO-β](https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE3MzMy?accessToken=dij4qbp6dcrofsanzbgobjsne9el8a2zkly2u5z82rxisd4wiwv1rhp0s2dub11e)
+
+ &nbsp;
+
+ ### **`AlpacaEval`**
+
+ <figure>
+ <img src="assets/img/alpaca_blog.png" alt="AlpacaEval 2.0 scores for models trained with different alignment methods">
+ <figcaption><b>Figure 1.</b> AlpacaEval 2.0 scores for the models trained with different alignment methods.</figcaption>
+ </figure>
+
+ &nbsp;
+
+ ### **`MT-Bench`**
+
+ <figure>
+ <img src="assets/img/mtbench_hf.png" alt="MT-Bench results by category">
+ <figcaption><b>Figure 2.</b> MT-Bench results by category.</figcaption>
+ </figure>
+
+ &nbsp;
+
+ ### **`IFEval`**
+
+ IFEval scores are measured with [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with the chat template applied. The scores for Llama-2-Chat (70B), Zephyr-β (7B), and Mixtral-8X7B-Instruct-v0.1 were originally reported in [this tweet](https://twitter.com/wiskojo/status/1739767758462877823). An example evaluation command follows the table.
+
+ | **Model Type** | **Prompt-Strict** | **Prompt-Loose** | **Inst-Strict** | **Inst-Loose** |
+ |--------------------|:-----------------:|:----------------:|:---------------:|:--------------:|
+ | **Llama-2-Chat (70B)** | 0.4436 | 0.5342 | 0.5468 | 0.6319 |
+ | **Zephyr-β (7B)** | 0.4233 | 0.4547 | 0.5492 | 0.5767 |
+ | **Mixtral-8X7B-Instruct-v0.1** | 0.5213 | **0.5712** | 0.6343 | **0.6823** |
+ | **Mistral-ORPO-⍺ (7B)** | 0.5009 | 0.5083 | 0.5995 | 0.6163 |
+ | **Mistral-ORPO-β (7B)** | **0.5287** | 0.5564 | **0.6355** | 0.6619 |
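The command below is a minimal sketch of how such scores can be reproduced; it assumes a recent lm-evaluation-harness release that ships the `ifeval` task and the `--apply_chat_template` flag, and the batch size is illustrative:

```bash
lm_eval --model hf \
    --model_args pretrained=kaist-ai/mistral-orpo-beta,dtype=bfloat16 \
    --tasks ifeval \
    --apply_chat_template \
    --batch_size 8
```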
assets/img/alpaca_blog.png ADDED
assets/img/mtbench_hf.png ADDED
main.py ADDED
@@ -0,0 +1,200 @@
+ import os
+ import time
+ import wandb
+ import torch
+ import argparse
+ from datasets import load_dataset
+ from typing import List, Dict, Union
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     DataCollatorForLanguageModeling
+ )
+
+ from src.args import default_args
+ from src.orpo_trainer import ORPOTrainer
+ from src.utils import preprocess_logits_for_metrics, dataset_split_selector
+
+
+ class ORPO(object):
+     def __init__(self, args) -> None:
+         self.start = time.gmtime()
+         self.args = args
+
+         # Load tokenizer
+         print(">>> 1. Loading Tokenizer")
+         self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, cache_dir=self.args.cache_dir)
+         if self.tokenizer.chat_template is None:
+             self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
+             print("    1-1. Chat Template Applied (<|user|> <|assistant|>)")
+         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+         # Load model
+         print(">>> 2. Loading Model")
+         if self.args.flash_attention_2:
+             self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
+                                                               cache_dir=self.args.cache_dir,
+                                                               torch_dtype=torch.bfloat16,
+                                                               attn_implementation="flash_attention_2")
+         else:
+             self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
+                                                               cache_dir=self.args.cache_dir,
+                                                               torch_dtype=torch.bfloat16)
+
+         # Load dataset
+         print(">>> 3. Loading Dataset")
+         self.data = load_dataset(self.args.data_name, cache_dir=self.args.cache_dir)
+
+         # Preprocess dataset
+         print(">>> 4. Filtering and Preprocessing Dataset")
+         data_split = dataset_split_selector(self.data)
+
+         if len(data_split) == 1:
+             self.is_test = False
+             train_split = data_split[0]
+         else:
+             self.is_test = True
+             train_split = data_split[0]
+             test_split = data_split[1]
+
+             test = self.data[test_split].filter(self.filter_dataset)
+             self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
+
+         train = self.data[train_split].filter(self.filter_dataset)
+         print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
+         self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
+
+         # Set WANDB & logging configurations
+         self.run_name = f"{self.args.model_name.split('/')[-1]}-{self.args.data_name.split('/')[-1]}-ORPO-{self.start.tm_mday}-{self.start.tm_hour}-{self.start.tm_min}"
+         self.save_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}")
+         self.log_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}/logs")
+
+         os.makedirs(self.save_dir, exist_ok=True)
+         os.makedirs(self.log_dir, exist_ok=True)
+
+     def preprocess_dataset(self, examples: Union[List, Dict]):
+         if 'instruction' in examples.keys():
+             prompt_key = 'instruction'
+             prompt = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item}], tokenize=False, add_generation_prompt=True) for item in examples[prompt_key]]
+             chosen = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_chosen}], tokenize=False) for item_prompt, item_chosen in zip(examples[prompt_key], examples['chosen'])]
+             rejected = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_rejected}], tokenize=False) for item_prompt, item_rejected in zip(examples[prompt_key], examples['rejected'])]
+         else:
+             prompt = [self.tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
+             chosen = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
+             rejected = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]
+
+         model_inputs = self.tokenizer(prompt,
+                                       max_length=self.args.response_max_length,
+                                       padding='max_length',
+                                       truncation=True,
+                                       return_tensors='pt')
+         pos_labels = self.tokenizer(chosen,
+                                     max_length=self.args.response_max_length,
+                                     padding='max_length',
+                                     truncation=True,
+                                     return_tensors='pt')
+         neg_labels = self.tokenizer(rejected,
+                                     max_length=self.args.response_max_length,
+                                     padding='max_length',
+                                     truncation=True,
+                                     return_tensors='pt')
+
+         model_inputs['positive_input_ids'] = pos_labels['input_ids']
+         model_inputs['positive_attention_mask'] = pos_labels['attention_mask']
+
+         model_inputs['negative_input_ids'] = neg_labels['input_ids']
+         model_inputs['negative_attention_mask'] = neg_labels['attention_mask']
+
+         return model_inputs
+
+     def filter_dataset(self, examples: Union[List, Dict]):
+         # Keep only rows whose prompt fits within prompt_max_length
+         if 'instruction' in examples.keys():
+             query = examples['instruction']
+             prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
+         else:
+             prompt_length = self.tokenizer.apply_chat_template([examples['chosen'][0]], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
+
+         return prompt_length < self.args.prompt_max_length
+
+     def prepare_trainer(self):
+         wandb.init(name=self.run_name)
+         arguments = TrainingArguments(
+             torch_compile=self.args.torch_compile,
+             output_dir=self.save_dir,                    # the output directory
+             logging_dir=self.log_dir,
+             logging_steps=50,
+             learning_rate=self.args.lr,
+             overwrite_output_dir=True,                   # overwrite the content of the output directory
+             num_train_epochs=self.args.num_train_epochs, # number of training epochs
+             per_device_train_batch_size=self.args.per_device_train_batch_size, # batch size for training
+             per_device_eval_batch_size=self.args.per_device_eval_batch_size,   # batch size for evaluation
+             evaluation_strategy=self.args.evaluation_strategy,                 # evaluation strategy ('epoch' by default)
+             save_strategy=self.args.evaluation_strategy,
+             optim=self.args.optim,
+             warmup_steps=self.args.warmup_steps,
+             gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+             gradient_checkpointing=True,
+             gradient_checkpointing_kwargs={'use_reentrant': True},
+             load_best_model_at_end=True,
+             do_train=True,
+             do_eval=self.is_test,
+             lr_scheduler_type=self.args.lr_scheduler_type,
+             remove_unused_columns=False,
+             report_to='wandb',
+             run_name=self.run_name,
+             bf16=True
+         )
+
+         data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
+
+         self.trainer = ORPOTrainer(
+             model=self.model,
+             alpha=self.args.alpha,
+             pad=self.tokenizer.pad_token_id,
+             args=arguments,
+             train_dataset=self.train,
+             eval_dataset=self.test if self.is_test else None,
+             data_collator=data_collator,
+             preprocess_logits_for_metrics=preprocess_logits_for_metrics
+         )
+
+     def run(self):
+         print(">>> 5. Preparing ORPOTrainer")
+         self.prepare_trainer()
+         self.trainer.train()
+
+         # Saving code for FSDP
+         if self.trainer.is_fsdp_enabled:
+             self.trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+         self.trainer.save_model()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser("ORPO")
+     args = default_args(parser)
+
+     # Set WANDB configurations
+     if args.wandb_entity is not None and args.wandb_project_name is not None:
+         os.environ["WANDB_ENTITY"] = args.wandb_entity
+         os.environ["WANDB_PROJECT"] = args.wandb_project_name
+     os.environ["TOKENIZERS_PARALLELISM"] = 'false'
+
+     print("================================================================================================\n")
+     print(f">>> Fine-tuning {args.model_name} with ORPO on {args.data_name}\n")
+     print("================================================================================================")
+     print("\n\n>>> Summary:")
+     print(f"    - Lambda              : {args.alpha}")
+     print(f"    - Training Epochs     : {args.num_train_epochs}")
+     print(f"    - Prompt Max Length   : {args.prompt_max_length}")
+     print(f"    - Response Max Length : {args.response_max_length}")
+
+     orpo = ORPO(args=args)
+     orpo.run()
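As `preprocess_dataset` above implies, the loader accepts two dataset schemas: an `instruction` column with plain-string `chosen`/`rejected` columns, or `chosen`/`rejected` columns that already hold full chat-message lists. A minimal sketch of the second form (the strings are illustrative, not taken from a real dataset):

```python
# Hypothetical example row in the chat-message format consumed by preprocess_dataset
example = {
    "chosen": [
        {"role": "user", "content": "What is ORPO?"},
        {"role": "assistant", "content": "A reference-model-free preference optimization method."},
    ],
    "rejected": [
        {"role": "user", "content": "What is ORPO?"},
        {"role": "assistant", "content": "I don't know."},
    ],
}
```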
outputs/alpacaeval/Mistral-ORPO-alpha.json ADDED
The diff for this file is too large to render. See raw diff
outputs/alpacaeval/Mistral-ORPO-beta.json ADDED
The diff for this file is too large to render. See raw diff
outputs/mtbench/Mistral-ORPO-alpha.jsonl ADDED
The diff for this file is too large to render. See raw diff
outputs/mtbench/Mistral-ORPO-beta.jsonl ADDED
The diff for this file is too large to render. See raw diff
src/accelerate/ds2.yaml ADDED
@@ -0,0 +1,21 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   gradient_accumulation_steps: 1
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
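With this DeepSpeed ZeRO-2 config (two processes on a single machine, bf16 mixed precision), training can be launched through `accelerate`. A minimal sketch, assuming the repository root as the working directory; the model name, dataset, and hyperparameter values are illustrative, not prescribed by the repo:

```bash
accelerate launch --config_file src/accelerate/ds2.yaml main.py \
    --model_name mistralai/Mistral-7B-v0.1 \
    --data_name HuggingFaceH4/ultrafeedback_binarized \
    --lr 5e-6 \
    --num_train_epochs 3 \
    --alpha 0.1 \
    --flash_attention_2
```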
src/args.py ADDED
@@ -0,0 +1,34 @@
+ def default_args(parser):
+     parser.add_argument("--cache_dir", default=None, type=str)
+     parser.add_argument("--save_dir", default='./saved', type=str)
+     parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
+     parser.add_argument("--model_name", default="gpt2", type=str)
+
+     # Training arguments
+     parser.add_argument("--torch_compile", default=True, type=bool)
+     parser.add_argument("--flash_attention_2", action='store_true')
+     parser.add_argument("--lr_scheduler_type", default="cosine", type=str)
+     parser.add_argument("--optim", default="paged_adamw_32bit", type=str)
+     parser.add_argument("--overwrite_output_dir", default=True, type=bool)
+     parser.add_argument("--lr", default=2e-5, type=float)
+     parser.add_argument("--num_proc", default=1, type=int)
+     parser.add_argument("--num_train_epochs", default=10, type=int)
+     parser.add_argument("--per_device_train_batch_size", default=2, type=int)
+     parser.add_argument("--per_device_eval_batch_size", default=2, type=int)
+     parser.add_argument("--warmup_steps", default=5000, type=int)
+     parser.add_argument("--evaluation_strategy", default='epoch', type=str)
+     parser.add_argument("--do_eval", action='store_true')
+     parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
+     parser.add_argument("--save_strategy", default='epoch', type=str)
+     parser.add_argument("--prompt_max_length", default=256, type=int)
+     parser.add_argument("--response_max_length", default=1024, type=int)
+     parser.add_argument("--alpha", default=1.0, type=float)
+
+     # Wandb configurations
+     parser.add_argument("--wandb_entity", default=None, type=str)
+     parser.add_argument("--wandb_project_name", default=None, type=str)
+
+     args = parser.parse_args()
+
+     return args
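One caveat worth noting: argparse's `type=bool` does not parse `"False"` as `False`, since any non-empty string is truthy, so `--torch_compile False` would still enable compilation. A common workaround, sketched here with a hypothetical `str2bool` helper (not part of this repo), is to parse the string explicitly:

```python
import argparse

def str2bool(value: str) -> bool:
    """Parse common true/false spellings into a real boolean."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}")

# e.g. parser.add_argument("--torch_compile", default=True, type=str2bool)
```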
src/orpo_trainer.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import wandb
+ from transformers import Trainer
+
+
+ class ORPOTrainer(Trainer):
+     def __init__(self, alpha, pad, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.pad = pad
+         self.alpha = alpha
+         self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+         print("Pad Token ID: ", self.pad)
+
+     def compute_custom_loss(self, logits, labels):
+         logits = logits.contiguous()
+
+         if labels is not None:
+             # Move labels to the correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             # Per-token cross entropy, averaged over the sequence dimension
+             loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(dim=-1)
+
+             return loss
+
+     def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
+         # The mask keeps response tokens only: full-sequence attention minus the prompt positions
+         mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
+         per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2,
+                                        index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
+         # Length-normalized average log probability of the response
+         return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)
+
+     def compute_loss(self, model, inputs, return_outputs=False):
+         if self.label_smoother is not None and "labels" in inputs:
+             labels = inputs.pop("labels")
+         else:
+             labels = None
+
+         # Generate the hidden states for 'chosen' and 'rejected'
+         neg_labels = inputs['negative_input_ids'].clone()
+         pos_labels = inputs['positive_input_ids'].clone()
+
+         neg_labels[neg_labels == self.pad] = -100
+         pos_labels[pos_labels == self.pad] = -100
+
+         outputs_neg = model(**{'input_ids': inputs['negative_input_ids'],
+                                'attention_mask': inputs['negative_attention_mask'],
+                                'labels': neg_labels}, output_hidden_states=True)
+         outputs_pos = model(**{'input_ids': inputs['positive_input_ids'],
+                                'attention_mask': inputs['positive_attention_mask'],
+                                'labels': pos_labels}, output_hidden_states=True)
+
+         # Calculate the NLL loss on the chosen responses
+         pos_loss = self.compute_custom_loss(logits=outputs_pos.logits, labels=inputs['positive_input_ids'])
+
+         # Calculate the average log probabilities
+         pos_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
+                                       chosen_inputs=inputs['positive_input_ids'],
+                                       chosen_attention_mask=inputs['positive_attention_mask'],
+                                       logits=outputs_pos.logits)
+         neg_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
+                                       chosen_inputs=inputs['negative_input_ids'],
+                                       chosen_attention_mask=inputs['negative_attention_mask'],
+                                       logits=outputs_neg.logits)
+
+         # Calculate the log odds ratio
+         log_odds = (pos_prob - neg_prob) - (torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)))
+         sig_ratio = torch.sigmoid(log_odds)
+         ratio = torch.log(sig_ratio)
+
+         # Calculate the final loss
+         loss = torch.mean(pos_loss - self.alpha * ratio).to(dtype=torch.bfloat16)
+
+         wandb.log({'Positive Geometric Mean': torch.mean(pos_prob).item(),
+                    'Negative Geometric Mean': torch.mean(neg_prob).item(),
+                    'Log Odds Ratio': torch.mean(ratio).item(),
+                    'Log Odds': torch.mean(log_odds).item()})
+
+         return (loss, outputs_pos) if return_outputs else loss
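For reference, `compute_loss` implements the ORPO objective from the paper. Writing $P_\theta(y \mid x)$ for the exponential of the length-normalized log probability returned by `compute_logps`, and $\text{odds}_\theta(y \mid x) = P_\theta(y \mid x) / (1 - P_\theta(y \mid x))$, the loss on a batch is

$$
\mathcal{L} = \mathbb{E}\left[\, \mathcal{L}_{\text{NLL}}(y_w) - \alpha \cdot \log \sigma\!\left(\log \frac{\text{odds}_\theta(y_w \mid x)}{\text{odds}_\theta(y_l \mid x)}\right) \right]
$$

where $y_w$ and $y_l$ are the chosen and rejected responses. In the code, `log_odds` is the argument of $\sigma$, `ratio` is $\log \sigma(\cdot)$, and `alpha` plays the role of the paper's $\lambda$ (as the summary printout in `main.py` also indicates).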
src/utils.py ADDED
@@ -0,0 +1,20 @@
+ from typing import List
+
+
+ def preprocess_logits_for_metrics(logits, labels):
+     if isinstance(logits, tuple):
+         logits = logits[0]
+     return logits.argmax(dim=-1)
+
+
+ def dataset_split_selector(data) -> List:
+     """
+     Selects the train (and, if present, test) split names for the given dataset.
+     Will be further updated.
+     """
+     if len(data.keys()) == 1:
+         return ['train']
+     elif 'train_prefs' in data.keys():
+         return ['train_prefs', 'test_prefs']
+     else:
+         return ['train', 'test']
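As a quick illustration (assuming the dataset's split names on the Hub are unchanged), `HuggingFaceH4/ultrafeedback_binarized` exposes `train_prefs`/`test_prefs` among its splits, so the selector picks those:

```python
from datasets import load_dataset
from src.utils import dataset_split_selector

data = load_dataset("HuggingFaceH4/ultrafeedback_binarized")
print(dataset_split_selector(data))  # expected: ['train_prefs', 'test_prefs']
```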