Upload folder using huggingface_hub
- .gitignore +4 -0
- README.md +53 -0
- assets/img/alpaca_blog.png +0 -0
- assets/img/mtbench_hf.png +0 -0
- main.py +200 -0
- outputs/alpacaeval/Mistral-ORPO-alpha.json +0 -0
- outputs/alpacaeval/Mistral-ORPO-beta.json +0 -0
- outputs/mtbench/Mistral-ORPO-alpha.jsonl +0 -0
- outputs/mtbench/Mistral-ORPO-beta.jsonl +0 -0
- src/accelerate/ds2.yaml +21 -0
- src/args.py +34 -0
- src/orpo_trainer.py +83 -0
- src/utils.py +20 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+wandb
+src/__pycache__
+scripts/run_orpo.sh
+src/accelerate/fsdp.yaml
README.md
ADDED
@@ -0,0 +1,53 @@
+# **ORPO**
+
+This is the official repository for <a class="link" href="https://arxiv.org/abs/2403.07691">**Reference-free Monolithic Preference Optimization with Odds Ratio**</a>. The detailed results from the paper can be found in:
+- [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=kaist-ai%2Fmistral-orpo-beta)
+- [AlpacaEval](#alpacaeval)
+- [MT-Bench](#mt-bench)
+- [IFEval](#ifeval)
+
+
+
+### **`Model Checkpoints`**
+
+Our models trained with ORPO can be found in:
+
+- [X] **Mistral-ORPO-⍺**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-alpha">kaist-ai/mistral-orpo-alpha</a>
+- [X] **Mistral-ORPO-β**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-beta">kaist-ai/mistral-orpo-beta</a>
+
+The corresponding training logs, including the average log probabilities of the chosen/rejected responses, are reported in:
+
+- [X] **Mistral-ORPO-⍺**: <a class="link" href="https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE1NzE0?accessToken=rms6o4mg5vo3feu1bvbpk632m4cspe19l0u1p4he3othx5bgean82chn9neiile6">Wandb Report for Mistral-ORPO-⍺</a>
+- [X] **Mistral-ORPO-β**: <a class="link" href="https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE3MzMy?accessToken=dij4qbp6dcrofsanzbgobjsne9el8a2zkly2u5z82rxisd4wiwv1rhp0s2dub11e">Wandb Report for Mistral-ORPO-β</a>
+
+
+
+### **`AlpacaEval`**
+
+<figure>
+  <img class="png" src="/assets/img/alpaca_blog.png" alt="AlpacaEval 2.0 scores for models trained with different alignment methods">
+  <figcaption><b>Figure 1.</b> AlpacaEval 2.0 scores for the models trained with different alignment methods.</figcaption>
+</figure>
+
+
+
+### **`MT-Bench`**
+
+<figure>
+  <img class="png" src="/assets/img/mtbench_hf.png" alt="MT-Bench results by category">
+  <figcaption><b>Figure 2.</b> MT-Bench results by category.</figcaption>
+</figure>
+
+
+
+### **`IFEval`**
+
+IFEval scores are measured with <a class="link" href="https://github.com/EleutherAI/lm-evaluation-harness">EleutherAI/lm-evaluation-harness</a> by applying the chat template. The scores for Llama-2-Chat (70B), Zephyr-β (7B), and Mixtral-8X7B-Instruct-v0.1 were originally reported in <a class="link" href="https://twitter.com/wiskojo/status/1739767758462877823">this tweet</a>.
+
+| **Model Type**                 | **Prompt-Strict** | **Prompt-Loose** | **Inst-Strict** | **Inst-Loose** |
+|--------------------------------|:-----------------:|:----------------:|:---------------:|:--------------:|
+| **Llama-2-Chat (70B)**         | 0.4436            | 0.5342           | 0.5468          | 0.6319         |
+| **Zephyr-β (7B)**              | 0.4233            | 0.4547           | 0.5492          | 0.5767         |
+| **Mixtral-8X7B-Instruct-v0.1** | 0.5213            | **0.5712**       | 0.6343          | **0.6823**     |
+| **Mistral-ORPO-⍺ (7B)**        | 0.5009            | 0.5083           | 0.5995          | 0.6163         |
+| **Mistral-ORPO-β (7B)**        | **0.5287**        | 0.5564           | **0.6355**      | 0.6619         |
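
As a quick way to try the checkpoints listed in the README above, here is a minimal sketch (not part of this commit) using the standard `transformers` generation API. It assumes the checkpoint ships a chat template, which the training code below sets up; the example prompt is illustrative.

```python
# Minimal sketch (not part of this commit): chat with a released ORPO checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "kaist-ai/mistral-orpo-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Build the prompt with the model's chat template, then generate.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is preference optimization?"}],  # illustrative prompt
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
output = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))
```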
assets/img/alpaca_blog.png
ADDED
assets/img/mtbench_hf.png
ADDED
main.py
ADDED
@@ -0,0 +1,200 @@
+
+import os
+import time
+import wandb
+import torch
+import argparse
+from datasets import load_dataset
+from typing import List, Dict, Union
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    DataCollatorForLanguageModeling
+)
+
+from src.args import default_args
+from src.orpo_trainer import ORPOTrainer
+from src.utils import preprocess_logits_for_metrics, dataset_split_selector
+
+
+class ORPO(object):
+    def __init__(self, args) -> None:
+        self.start = time.gmtime()
+        self.args = args
+
+        # Load tokenizer; fall back to a Zephyr-style template if the model ships no chat template.
+        print(">>> 1. Loading Tokenizer")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, cache_dir=self.args.cache_dir)
+        if self.tokenizer.chat_template is None:
+            self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
+            print("   1-1. Chat Template Applied (<|user|> <|assistant|>)")
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+        # Load model in bfloat16, optionally with Flash Attention 2.
+        print(">>> 2. Loading Model")
+        if self.args.flash_attention_2:
+            self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
+                                                              cache_dir=self.args.cache_dir,
+                                                              torch_dtype=torch.bfloat16,
+                                                              attn_implementation="flash_attention_2")
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
+                                                              cache_dir=self.args.cache_dir,
+                                                              torch_dtype=torch.bfloat16)
+
+        # Load dataset
+        print(">>> 3. Loading Dataset")
+        self.data = load_dataset(self.args.data_name, cache_dir=self.args.cache_dir)
+
+        # Filter and preprocess dataset
+        print(">>> 4. Filtering and Preprocessing Dataset")
+        data_split = dataset_split_selector(self.data)
+
+        if len(data_split) == 1:
+            self.is_test = False
+            train_split = data_split[0]
+        else:
+            self.is_test = True
+            train_split = data_split[0]
+            test_split = data_split[1]
+
+            test = self.data[test_split].filter(self.filter_dataset)
+            self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
+
+        train = self.data[train_split].filter(self.filter_dataset)
+        print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
+        self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
+
+        # Set WANDB & logging configurations
+        self.run_name = f"{self.args.model_name.split('/')[-1]}-{self.args.data_name.split('/')[-1]}-ORPO-{self.start.tm_mday}-{self.start.tm_hour}-{self.start.tm_min}"
+        self.save_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}")
+        self.log_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}/logs")
+
+        os.makedirs(self.save_dir, exist_ok=True)
+        os.makedirs(self.log_dir, exist_ok=True)
+
+    def preprocess_dataset(self, examples: Union[List, Dict]):
+        # Build prompt-only, prompt+chosen, and prompt+rejected strings via the chat template.
+        if 'instruction' in examples.keys():
+            prompt_key = 'instruction'
+            prompt = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item}], tokenize=False, add_generation_prompt=True) for item in examples[prompt_key]]
+            chosen = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_chosen}], tokenize=False) for item_prompt, item_chosen in zip(examples[prompt_key], examples['chosen'])]
+            rejected = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_rejected}], tokenize=False) for item_prompt, item_rejected in zip(examples[prompt_key], examples['rejected'])]
+        else:
+            prompt = [self.tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
+            chosen = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
+            rejected = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]
+
+        model_inputs = self.tokenizer(prompt,
+                                      max_length=self.args.response_max_length,
+                                      padding='max_length',
+                                      truncation=True,
+                                      return_tensors='pt')
+        pos_labels = self.tokenizer(chosen,
+                                    max_length=self.args.response_max_length,
+                                    padding='max_length',
+                                    truncation=True,
+                                    return_tensors='pt')
+        neg_labels = self.tokenizer(rejected,
+                                    max_length=self.args.response_max_length,
+                                    padding='max_length',
+                                    truncation=True,
+                                    return_tensors='pt')
+
+        model_inputs['positive_input_ids'] = pos_labels['input_ids']
+        model_inputs['positive_attention_mask'] = pos_labels['attention_mask']
+
+        model_inputs['negative_input_ids'] = neg_labels['input_ids']
+        model_inputs['negative_attention_mask'] = neg_labels['attention_mask']
+
+        return model_inputs
+
+    def filter_dataset(self, examples: Union[List, Dict]):
+        # Keep only rows whose templated prompt fits within prompt_max_length tokens.
+        if 'instruction' in examples.keys():
+            query = examples['instruction']
+            prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
+        else:
+            prompt_length = self.tokenizer.apply_chat_template([examples['chosen'][0]], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
+
+        return prompt_length < self.args.prompt_max_length
+
+    def prepare_trainer(self):
+        wandb.init(name=self.run_name)
+        arguments = TrainingArguments(
+            torch_compile=self.args.torch_compile,
+            output_dir=self.save_dir,            # the output directory
+            logging_dir=self.log_dir,
+            logging_steps=50,
+            learning_rate=self.args.lr,
+            overwrite_output_dir=True,           # overwrite the content of the output directory
+            num_train_epochs=self.args.num_train_epochs,                        # number of training epochs
+            per_device_train_batch_size=self.args.per_device_train_batch_size,  # batch size for training
+            per_device_eval_batch_size=self.args.per_device_eval_batch_size,    # batch size for evaluation
+            evaluation_strategy=self.args.evaluation_strategy,                  # evaluation strategy ('epoch' by default)
+            save_strategy=self.args.evaluation_strategy,                        # kept in sync with evaluation for load_best_model_at_end
+            optim=self.args.optim,
+            warmup_steps=self.args.warmup_steps,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+            gradient_checkpointing=True, #if ('llama' in self.args.model_name.lower()) or ('mistral' in self.args.model_name.lower()) else False,
+            gradient_checkpointing_kwargs={'use_reentrant': True},
+            load_best_model_at_end=True,
+            do_train=True,
+            do_eval=self.is_test,
+            lr_scheduler_type=self.args.lr_scheduler_type,
+            remove_unused_columns=False,
+            report_to='wandb',
+            run_name=self.run_name,
+            bf16=True
+        )
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
+
+        self.trainer = ORPOTrainer(
+            model=self.model,
+            alpha=self.args.alpha,
+            pad=self.tokenizer.pad_token_id,
+            args=arguments,
+            train_dataset=self.train,
+            eval_dataset=self.test if self.is_test else None,
+            data_collator=data_collator,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics
+        )
+
+    def run(self):
+        print(">>> 5. Preparing ORPOTrainer")
+        self.prepare_trainer()
+        self.trainer.train()
+
+        # Saving code for FSDP
+        if self.trainer.is_fsdp_enabled:
+            self.trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+        self.trainer.save_model()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("ORPO")
+    args = default_args(parser)
+
+    # Set WANDB configurations
+    if args.wandb_entity is not None and args.wandb_project_name is not None:
+        os.environ["WANDB_ENTITY"] = args.wandb_entity
+        os.environ["WANDB_PROJECT"] = args.wandb_project_name
+    os.environ["TOKENIZERS_PARALLELISM"] = 'false'
+
+    print("================================================================================================\n")
+    print(f">>> Fine-tuning {args.model_name} with ORPO on {args.data_name}\n")
+    print("================================================================================================")
+    print("\n\n>>> Summary:")
+    print(f"    - Lambda              : {args.alpha}")
+    print(f"    - Training Epochs     : {args.num_train_epochs}")
+    print(f"    - Prompt Max Length   : {args.prompt_max_length}")
+    print(f"    - Response Max Length : {args.response_max_length}")
+
+    item = ORPO(args=args)
+    item.run()
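
To make `preprocess_dataset` concrete, here is a small self-contained sketch (not part of this commit) showing what `apply_chat_template` produces for the prompt-only and prompt+chosen strings. It uses `gpt2` purely as a stand-in tokenizer and an abridged version of the fallback template that `ORPO.__init__` assigns (the system branch is omitted); the example messages are illustrative.

```python
# Sketch (not part of this commit): how main.py builds prompt/chosen strings.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in; any tokenizer works here
# Abridged version of the fallback template assigned in ORPO.__init__.
tokenizer.chat_template = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n"
    "{% endfor %}"
)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "2+2?"}], tokenize=False, add_generation_prompt=True)
chosen = tokenizer.apply_chat_template(
    [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}], tokenize=False)

print(prompt)  # ends with the '<|assistant|>' generation header
print(chosen)  # full conversation; the assistant reply is terminated by the EOS token
```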
outputs/alpacaeval/Mistral-ORPO-alpha.json
ADDED
The diff for this file is too large to render.
outputs/alpacaeval/Mistral-ORPO-beta.json
ADDED
The diff for this file is too large to render.
outputs/mtbench/Mistral-ORPO-alpha.jsonl
ADDED
The diff for this file is too large to render.
outputs/mtbench/Mistral-ORPO-beta.jsonl
ADDED
The diff for this file is too large to render.
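
The four files above ship the raw model generations behind the AlpacaEval and MT-Bench numbers in the README. A quick sketch (not part of this commit) to peek at them; the exact field names depend on the AlpacaEval and MT-Bench (FastChat) output formats and are not asserted here.

```python
# Sketch (not part of this commit): inspect the shipped benchmark outputs.
import json

# MT-Bench generations are stored as one JSON object per line.
with open("outputs/mtbench/Mistral-ORPO-beta.jsonl") as f:
    first = json.loads(f.readline())
print(sorted(first.keys()))  # field names follow the MT-Bench answer format

# AlpacaEval outputs are a single JSON document.
with open("outputs/alpacaeval/Mistral-ORPO-beta.json") as f:
    records = json.load(f)
print(len(records))
```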
src/accelerate/ds2.yaml
ADDED
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
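
This config runs DeepSpeed ZeRO stage 2 in bf16 on two processes (i.e., two GPUs on a single machine). With it, a typical launch would look like `accelerate launch --config_file src/accelerate/ds2.yaml main.py --flash_attention_2 ...`, plus the model/data flags defined in `src/args.py`; the actual launch script, `scripts/run_orpo.sh`, is gitignored above, so the exact flags used for the released checkpoints are not part of this commit.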
src/args.py
ADDED
@@ -0,0 +1,34 @@
+def default_args(parser):
+    parser.add_argument("--cache_dir", default=None, type=str)
+    parser.add_argument("--save_dir", default='./saved', type=str)
+    parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
+    parser.add_argument("--model_name", default="gpt2", type=str)
+
+    # Training Arguments
+    # Caveat: argparse's type=bool treats any non-empty string (including "False") as True;
+    # omit the flag to keep the default rather than passing a value.
+    parser.add_argument("--torch_compile", default=True, type=bool)
+    parser.add_argument("--flash_attention_2", action='store_true')
+    parser.add_argument("--lr_scheduler_type", default="cosine", type=str)
+    parser.add_argument("--optim", default="paged_adamw_32bit", type=str)
+    parser.add_argument("--overwrite_output_dir", default=True, type=bool)
+    parser.add_argument("--lr", default=2e-5, type=float)
+    parser.add_argument("--num_proc", default=1, type=int)
+    parser.add_argument("--num_train_epochs", default=10, type=int)
+    parser.add_argument("--per_device_train_batch_size", default=2, type=int)
+    parser.add_argument("--per_device_eval_batch_size", default=2, type=int)
+    parser.add_argument("--warmup_steps", default=5000, type=int)
+    parser.add_argument("--evaluation_strategy", default='epoch', type=str)
+    parser.add_argument("--do_eval", action='store_true')
+    parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
+    parser.add_argument("--save_strategy", default='epoch', type=str)
+    parser.add_argument("--prompt_max_length", default=256, type=int)
+    parser.add_argument("--response_max_length", default=1024, type=int)
+    parser.add_argument("--alpha", default=1.0, type=float)
+
+    # Wandb Configurations
+    parser.add_argument("--wandb_entity", default=None, type=str)
+    parser.add_argument("--wandb_project_name", default=None, type=str)
+
+    args = parser.parse_args()
+
+    return args
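
The `type=bool` caveat noted in the comment above is easy to reproduce: `bool()` on any non-empty string is truthy, so passing a value cannot disable the flag. A tiny self-contained sketch:

```python
# Sketch (not part of this commit): argparse's type=bool pitfall for --torch_compile.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--torch_compile", default=True, type=bool)

args = parser.parse_args(["--torch_compile", "False"])
print(args.torch_compile)  # True, because bool("False") is truthy
```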
src/orpo_trainer.py
ADDED
@@ -0,0 +1,83 @@
+
+import torch
+import wandb
+from transformers import Trainer
+
+
+class ORPOTrainer(Trainer):
+    def __init__(self, alpha, pad, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pad = pad
+        self.alpha = alpha
+        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+        print("Pad Token ID: ", self.pad)
+
+    def compute_custom_loss(self, logits, labels):
+        logits = logits.contiguous()
+
+        if labels is not None:
+            # Move labels to the correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            # Per-token cross-entropy, averaged over the sequence dimension
+            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(dim=-1)
+
+        return loss
+
+    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
+        # Mask out the prompt tokens so only response tokens contribute
+        mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
+        per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2,
+                                       index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
+        # Length-normalized (average) log probability of the response
+        return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        if self.label_smoother is not None and "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
+
+        # Generate the hidden states for 'chosen' and 'rejected'
+        neg_labels = inputs['negative_input_ids'].clone()
+        pos_labels = inputs['positive_input_ids'].clone()
+
+        neg_labels[neg_labels == self.pad] = -100
+        pos_labels[pos_labels == self.pad] = -100
+
+        outputs_neg = model(**{'input_ids': inputs['negative_input_ids'],
+                               'attention_mask': inputs['negative_attention_mask'],
+                               'labels': neg_labels,}, output_hidden_states=True)
+        outputs_pos = model(**{'input_ids': inputs['positive_input_ids'],
+                               'attention_mask': inputs['positive_attention_mask'],
+                               'labels': pos_labels,}, output_hidden_states=True)
+
+        # Calculate the NLL loss on the chosen responses
+        pos_loss = self.compute_custom_loss(logits=outputs_pos.logits, labels=inputs['positive_input_ids'])
+
+        # Calculate the average log probabilities of chosen/rejected responses
+        pos_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
+                                      chosen_inputs=inputs['positive_input_ids'],
+                                      chosen_attention_mask=inputs['positive_attention_mask'],
+                                      logits=outputs_pos.logits)
+        neg_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
+                                      chosen_inputs=inputs['negative_input_ids'],
+                                      chosen_attention_mask=inputs['negative_attention_mask'],
+                                      logits=outputs_neg.logits)
+
+        # Calculate the log odds ratio between chosen and rejected responses
+        log_odds = (pos_prob - neg_prob) - (torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)))
+        sig_ratio = torch.nn.functional.sigmoid(log_odds)
+        ratio = torch.log(sig_ratio)
+
+        # Calculate the final loss: NLL on chosen minus the weighted log odds ratio term
+        loss = torch.mean(pos_loss - self.alpha * ratio).to(dtype=torch.bfloat16)
+
+        wandb.log({'Positive Geometric Mean': torch.mean(pos_prob).item(),
+                   'Negative Geometric Mean': torch.mean(neg_prob).item(),
+                   'Log Odds Ratio': torch.mean(ratio).item(),
+                   'Log Odds': torch.mean(log_odds).item()})
+
+        return (loss, outputs_pos) if return_outputs else loss
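
For reference, `compute_loss` above matches the objective in the ORPO paper. Writing $P_\theta(y\mid x)$ for the length-normalized (per-token geometric mean) probability of a response, whose logarithm is what `compute_logps` returns, the code computes:

$$
\operatorname{odds}_\theta(y\mid x)=\frac{P_\theta(y\mid x)}{1-P_\theta(y\mid x)},\qquad
\mathcal{L}_{\mathrm{OR}}=-\log\sigma\!\left(\log\frac{\operatorname{odds}_\theta(y_w\mid x)}{\operatorname{odds}_\theta(y_l\mid x)}\right),
$$

$$
\mathcal{L}_{\mathrm{ORPO}}=\mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[\mathcal{L}_{\mathrm{SFT}}+\lambda\cdot\mathcal{L}_{\mathrm{OR}}\right].
$$

In the code, `log_odds` is the inner log-ratio, `ratio` is $\log\sigma(\texttt{log\_odds})$, `pos_loss` is the SFT (NLL) term on the chosen response, and `alpha` plays the role of $\lambda$, which is why `main.py` prints it under the `Lambda` label.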
src/utils.py
ADDED
@@ -0,0 +1,20 @@
+
+from typing import List
+
+
+def preprocess_logits_for_metrics(logits, labels):
+    # Keep only the argmax predictions to reduce memory during evaluation
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    return logits.argmax(dim=-1)
+
+
+def dataset_split_selector(data) -> List:
+    """
+    Automatically select the train/test split names for the given dataset.
+    Will be further updated.
+    """
+    if len(data.keys()) == 1:
+        return ['train']
+    elif 'train_prefs' in data.keys():
+        return ['train_prefs', 'test_prefs']
+    else:
+        return ['train', 'test']
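
Since `dataset_split_selector` only inspects `.keys()`, its behavior is easy to check with plain dicts standing in for a `DatasetDict`; a quick sketch (not part of this commit), assuming the repo root is on `PYTHONPATH`:

```python
# Sketch (not part of this commit): dataset_split_selector only looks at split names.
from src.utils import dataset_split_selector

print(dataset_split_selector({'train': None}))                       # ['train']
print(dataset_split_selector({'train_prefs': 0, 'test_prefs': 0}))   # ['train_prefs', 'test_prefs']
print(dataset_split_selector({'train': 0, 'test': 0}))               # ['train', 'test']
```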