hanchier committed · Commit e0b11c9 · 1 Parent(s): d613367
app.py ADDED
@@ -0,0 +1,241 @@
# https://huggingface.co/spaces/Glaciohound/LM-Steer

import torch
import streamlit as st
import random
import numpy as np
import pandas as pd
from lm_steer.models.get_model import get_model


@st.cache_resource(show_spinner="Loading model...")
def st_get_model(model_name, low_resource_mode):
    device = torch.device("cuda:0") if torch.cuda.is_available() \
        else torch.device("cpu")
    model, tokenizer = get_model(
        model_name, "final_layer", "multiply",
        4,
        1000, 1e-3, 1e-2, low_resource_mode
    )
    model.to_device(device)
    ckpt = torch.load(f"checkpoints/{model_name}.pt", map_location=device)
    model.load_state_dict(ckpt[1])
    return model, tokenizer


def word_embedding_space_analysis(model, tokenizer, dim):
    # Low-rank steering matrix for this dimension: P1 @ P2^T.
    matrix = model.steer.projector1.data[dim].matmul(
        model.steer.projector2.data[dim].transpose(0, 1))
    U, S, Vh = torch.linalg.svd(matrix)
    embeddings = model.steer.lm_head.weight

    data = []
    for _i in range(10):
        # Project the vocabulary embeddings onto the _i-th right singular
        # vector and keep the 20 most extreme tokens on each side.
        left_tokens = embeddings.matmul(Vh[_i]).argsort()[-20:].flip(0)
        right_tokens = embeddings.matmul(Vh[_i]).argsort()[:20]

        def filter_words(side_tokens):
            output = []
            for t in side_tokens:
                word = tokenizer.decode([t])
                if not word[0].isalpha() and word[1:].isalpha():
                    output.append(word[1:]+"-")
            return output

        data.append([
            ", ".join(filter_words(side_tokens))
            for side_tokens in [left_tokens, right_tokens]
        ])
    st.table(pd.DataFrame(
        data,
        columns=["One Direction", "Another Direction"],
        index=[f"Dim {_i}" for _i in range(10)],
    ))


def main():
    # set up the page
    random.seed(0)
    title = "LM-Steer: Word Embeddings Are Steers for Language Models"
    st.set_page_config(
        layout="wide",
        page_title=title,
        page_icon="🛞",
    )
    st.title(title)
    '''
    Live demo for the paper ["**LM-Steer: Word Embeddings Are Steers for
    Language Models**"](https://arxiv.org/abs/2305.12798) (**ACL 2024
    Outstanding Paper Award**) by Chi Han, Jialiang Xu, Manling Li, Yi Fung,
    Chenkai Sun, Nan Jiang, Tarek Abdelzaher, Heng Ji. GitHub repository:
    https://github.com/Glaciohound/LM-Steer.
    '''
    st.subheader("Overview")
    st.image('https://raw.githubusercontent.com/Glaciohound/LM-Steer'
             '/refs/heads/main/assets/overview_fig.jpg')
    '''
    Language models (LMs) automatically learn word embeddings during
    pre-training on language corpora. Although word embeddings are usually
    interpreted as feature vectors for individual words, their roles in
    language model generation remain underexplored. In this work, we
    theoretically and empirically revisit output word embeddings and find
    that their linear transformations are equivalent to steering language
    model generation styles. We name such steers LM-Steers and find that
    they exist in LMs of all sizes. Steering each style requires learning
    parameters equal to only 0.2% of the original LM's size.
    '''
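
    # A sketch of the steering computation (see Projected_Adaptor in
    # lm_steer/models/steers.py): with the frozen output embedding matrix W,
    # a final hidden state h, a steer value z, and low-rank projectors
    # P1 and P2, the steered logits are
    #     logits = (h + epsilon * z * (h @ P1) @ P2.T) @ W.T
    # i.e. each steer dimension is a learned low-rank linear transformation
    # of the output word embeddings, scaled by its steer value z.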

    # set up the model
    st.divider()
    st.divider()
    st.subheader("Select a model:")
    '''
    Due to resource limits, we can only provide a few models for steering
    here. You can also refer to the GitHub repository
    https://github.com/Glaciohound/LM-Steer for hosting larger models.
    '''
    col1, col2 = st.columns(2)
    st.session_state.model_name = col1.selectbox(
        "Select a model to steer",
        [
            "gpt2",
            "gpt2-medium",
            "gpt2-large",
            "EleutherAI/pythia-70m",
            "EleutherAI/pythia-160m",
            "EleutherAI/pythia-410m",
            # "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b",
            # "EleutherAI/pythia-2.8b", "EleutherAI/pythia-6.9b",
            # "EleutherAI/gpt-j-6B",
        ],
    )
    low_resource_mode = st.session_state.model_name in (
        "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
        "EleutherAI/pythia-6.9b", "EleutherAI/gpt-j-6B",
    )
    model, tokenizer = st_get_model(
        st.session_state.model_name, low_resource_mode)
    # approximate steered parameter count, in millions
    num_param = model.steer.projector1.data.shape[1] ** 2 / 1024 ** 2
    total_param = sum(p.numel() for _, p in model.named_parameters()) / \
        1024 ** 2
    ratio = num_param / total_param
    col2.write(f"Steered {num_param:.1f}M out of {total_param:.1f}M "
               f"parameters, ratio: {ratio:.2%}")

    # steering
    steer_range = 4.
    steer_interval = 0.5
    st.subheader("Enter a sentence and steer the model")
    st.session_state.prompt = st.text_input(
        "Enter a prompt",
        st.session_state.get("prompt", "My life")
    )
    col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
    sentiment = col1.slider(
        "Sentiment", -steer_range, steer_range, 3.0, steer_interval)
    detoxification = col2.slider(
        "Detoxification Strength", -steer_range, steer_range, 0.0,
        steer_interval)
    max_length = col3.number_input("Max length", 50, 300, 50, 50)
    col1, col2, col3, _ = st.columns(4)
    randomness = col2.checkbox("Random sampling", value=False)

    if "output" not in st.session_state:
        st.session_state.output = ""
    if col1.button("Steer and generate!", type="primary"):
        steer_values = [detoxification, 0, sentiment, 0]
        st.session_state.output = model.generate(
            st.session_state.prompt,
            steer_values,
            seed=None if randomness else 0,
            min_length=0,
            max_length=max_length,
            do_sample=True,
        )
    analyzed_text = \
        st.text_area("Generated text:", st.session_state.output, height=200)

    # Analyzing the sentence
    st.divider()
    st.divider()
    st.subheader("Analyzing Styled Texts")
    '''
    LM-Steer also serves as a probe for analyzing text: here we use it to
    analyze the sentiment and toxicity of the text in the box above. You
    can also modify the text or use your own. Please note that the two
    dimensions can be entangled, as a negative sentiment may also detoxify
    the text.
    '''
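
    # How the analysis below works (a sketch; see steer_analysis in
    # lm_steer/models/model_base.py): the text is scored under a grid of
    # steer values z, and the per-value LM losses are turned into a
    # distribution, roughly
    #     P(z | text) = softmax_z(-loss(text; z))
    # so a peak at a positive z means the text reads as strongly steered
    # in that direction.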
    if st.session_state.get("output", "") != "" and \
            st.button("Analyze the styled text", type="primary"):
        col1, col2 = st.columns(2)
        for name, col, dim, color in zip(
            ["Sentiment", "Detoxification"],
            [col1, col2],
            [2, 0],
            ["#ff7f0e", "#1f77b4"],
        ):
            col.subheader(name)
            # classification
            col.markdown("##### Dimension-Wise Classification Distribution")
            _, dist_list, _ = model.steer_analysis(
                analyzed_text,
                dim, -steer_range, steer_range,
                bins=2*int(steer_range)+1,
            )
            dist_list = np.array(dist_list)
            col.bar_chart(
                pd.DataFrame(
                    {
                        "Value": dist_list[:, 0],
                        "Probability": dist_list[:, 1],
                    }
                ), x="Value", y="Probability",
                color=color,
            )

            # key tokens
            pos_steer, neg_steer = np.zeros((2, 4))
            pos_steer[dim] = 1
            neg_steer[dim] = -1
            _, token_evidence = model.evidence_words(
                analyzed_text,
                [pos_steer, neg_steer],
            )
            tokens = tokenizer(analyzed_text).input_ids
            tokens = [f"{i:3d}: {tokenizer.decode([t])}"
                      for i, t in enumerate(tokens)]
            col.markdown("##### Token's Evidence Score in the Dimension")
            col.bar_chart(
                pd.DataFrame(
                    {
                        "Token": tokens[1:],
                        "Evidence": token_evidence,
                    }
                ), x="Token", y="Evidence",
                horizontal=True, color=color,
            )

    st.divider()
    st.divider()
    st.subheader("The Word Embedding Space Analysis")
    '''
    LM-Steer provides a lens on the word embedding space: which embedding
    dimensions contribute to, or contrast with, a specific style. This
    analysis can be used to understand the word embedding space and how it
    steers the model's generation.
    Note that due to the bidirectional nature of the embedding space, in
    each dimension sometimes only one side of the word embeddings is most
    relevant to the style (it can be either the left or the right column).
    '''
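
    # How the table below is built (see word_embedding_space_analysis
    # above): the low-rank steering matrix P1 @ P2.T of the chosen
    # dimension is factored with an SVD, and for each of the top-10
    # singular directions the vocabulary embeddings are projected onto
    # that direction; the highest- and lowest-scoring tokens form the two
    # columns of the table.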
    dimension = st.selectbox(
        "Select a dimension to analyze",
        ["Sentiment", "Detoxification"],
    )
    dim = 2 if dimension == "Sentiment" else 0
    word_embedding_space_analysis(model, tokenizer, dim)


if __name__ == "__main__":
    main()
lm_steer/__init__.py ADDED
File without changes
lm_steer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (166 Bytes).
 
lm_steer/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.44 kB).
 
lm_steer/arguments.py ADDED
@@ -0,0 +1,59 @@
from pprint import pprint
import argparse
from .utils import set_seed


def parse_args():
    # Model related
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str,
                        default='EleutherAI/gpt-neo-2.7B')
    parser.add_argument("--adaptor_class", type=str, default="multiply")
    parser.add_argument("--adapted_component", type=str, default="final_layer")
    parser.add_argument("--epsilon", type=float, default=1e-3)
    parser.add_argument("--init_var", type=float, default=1e-2)
    parser.add_argument("--rank", type=int, default=1000)
    parser.add_argument("--num_steers", type=int, default=10)
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--low_resource_mode", action="store_true")

    # Data related
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--eval_file", type=str, default=None)
    parser.add_argument("--output_file", type=str, default=None)
    parser.add_argument("--data_size", type=int, default=None)
    parser.add_argument("--split", type=str, default=None)

    # Training related
    parser.add_argument("--regularization", type=float, default=0)
    parser.add_argument("--optimizer", type=str, default="Adam")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma_mean", type=float, default=0.99)
    parser.add_argument("--n_steps", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--ckpt_name", type=str, default=None)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--log_step", type=int, default=500)
    parser.add_argument("--subset", type=int, default=None)
    parser.add_argument("--dummy_steer", type=int, default=None)
    parser.add_argument("--training_steer", type=int, default=0)

    # Evaluation related
    parser.add_argument("--eval_size", type=int, default=None)
    parser.add_argument("--steer_values", default=None, nargs="*", type=float)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--top_p", type=float, default=1)

    # Transfer related
    parser.add_argument("--transfer_from", type=str, default=None)

    args = parser.parse_args()

    set_seed(args.seed)

    print("arguments:")
    pprint(args.__dict__)
    return args
lm_steer/models/__pycache__/get_model.cpython-310.pyc ADDED
Binary file (1.48 kB).
 
lm_steer/models/__pycache__/model_base.cpython-310.pyc ADDED
Binary file (4.88 kB).
 
lm_steer/models/__pycache__/model_gpt_neo.cpython-310.pyc ADDED
Binary file (2.6 kB).
 
lm_steer/models/__pycache__/model_gpt_neox.cpython-310.pyc ADDED
Binary file (3.7 kB).
 
lm_steer/models/__pycache__/model_utils.cpython-310.pyc ADDED
Binary file (2.23 kB).
 
lm_steer/models/__pycache__/steers.cpython-310.pyc ADDED
Binary file (3.07 kB).
 
lm_steer/models/get_model.py ADDED
@@ -0,0 +1,43 @@
def get_model(model_name, adapted_component, adaptor_class, num_steers, rank,
              epsilon, init_var, low_resource_mode):
    if model_name.startswith("EleutherAI/gpt-neo") or \
            model_name.startswith("gpt2"):
        from lm_steer.models.model_gpt_neo import Switching_GPTNeoModel
        model = Switching_GPTNeoModel(
            model_name, adapted_component, adaptor_class, num_steers, rank,
            epsilon, init_var, low_resource_mode)
        return model, model.tokenizer
    elif model_name.startswith("lora-gpt2"):
        from lm_steer.models.model_lora_gpt_neo import LORA_GPTNeoModel
        model = LORA_GPTNeoModel(model_name, rank, epsilon)
        return model, model.tokenizer
    elif model_name.startswith("embedding_tuning"):
        from lm_steer.models.model_embedding_tuning_gpt_neo import \
            EmbeddingTuning_GPTNeoModel
        model = EmbeddingTuning_GPTNeoModel(model_name)
        return model, model.tokenizer
    elif model_name.startswith("prefix-gpt2"):
        from lm_steer.models.model_prefix_gpt_neo import PREFIX_GPTNeoModel
        model = PREFIX_GPTNeoModel(model_name)
        return model, model.tokenizer
    elif model_name.startswith("EleutherAI/pythia"):
        from lm_steer.models.model_gpt_neox import Switching_GPTNeoXModel
        model = Switching_GPTNeoXModel(
            model_name, adapted_component, adaptor_class, num_steers, rank,
            epsilon, init_var, low_resource_mode)
        return model, model.tokenizer
    elif model_name.startswith("EleutherAI/gpt-j"):
        from lm_steer.models.model_gpt_j import Switching_GPTJModel
        model = Switching_GPTJModel(
            model_name, adapted_component, adaptor_class, num_steers, rank,
            epsilon, init_var, low_resource_mode)
        return model, model.tokenizer
    elif model_name.startswith("microsoft/DialoGPT"):
        from lm_steer.models.model_dialogpt import Switching_DialoGPTModel
        model = Switching_DialoGPTModel(
            model_name, adapted_component, adaptor_class, num_steers, rank,
            epsilon, init_var, low_resource_mode)
        return model, model.tokenizer
    else:
        raise NotImplementedError()
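
A minimal usage sketch of this factory, mirroring the call in app.py (the checkpoint path and the choice of steer dimension are the demo's assumptions):

import torch
from lm_steer.models.get_model import get_model

model, tokenizer = get_model(
    "gpt2", "final_layer", "multiply",
    num_steers=4, rank=1000, epsilon=1e-3, init_var=1e-2,
    low_resource_mode=False)
model.to_device(torch.device("cpu"))
# app.py loads trained steer weights before generating:
ckpt = torch.load("checkpoints/gpt2.pt", map_location="cpu")
model.load_state_dict(ckpt[1])
# dimension 2 is sentiment in the demo; positive values steer positive
text = model.generate("My life", [0, 0, 3.0, 0], seed=0)
print(text)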
lm_steer/models/model_base.py ADDED
@@ -0,0 +1,173 @@
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F

from lm_steer.utils import set_seed
from .model_utils import find_max_subspans


punctuations = [
    '!', '"', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.',
    # '/', '#',
    ':', ';', '<', '=', '>', '?', '@',
    '[', '\\', ']', '^', '_', '`',
    '{', '|', '}', '~',
    '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', 'µ', '¶', '·',
    '¸', '¹', 'º', '»', '¼', '½', '¾',
    '\n', ' ',
]


class LMSteerBase(nn.Module):
    def evidence_words(self, prompt, comparing_steer_values,
                       truncation_length=1024, max_segments=4, max_length=10):
        if isinstance(comparing_steer_values, list):
            comparing_steer_values = \
                torch.Tensor(comparing_steer_values).to(self.device)
        if (comparing_steer_values[0] - comparing_steer_values[1]
                ).abs().sum() <= 0.2:
            # The two steer settings are nearly identical: no evidence.
            return [(prompt, None)], []
        tokenized = self.tokenizer(
            prompt, return_tensors="pt",
            max_length=truncation_length, truncation=True)
        input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
        input_ids = input_ids.expand(2, -1)
        attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
            self.device)
        attention_mask = attention_mask.expand(2, -1)
        self.steer.set_value(comparing_steer_values)
        with torch.no_grad():
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids)
        length = input_ids.shape[1]
        loss_token = F.cross_entropy(
            output.logits[:, :-1].reshape(2 * (length - 1), -1),
            input_ids[:, 1:].reshape(-1),
            reduction="none"
        )
        loss_token = loss_token.reshape(2, length - 1)

        # A token is evidence for the first steer setting when it is more
        # likely under it than under the second one.
        token_evidence = (- loss_token[0] + loss_token[1])
        tokens = input_ids[0]
        evidence_segments = find_max_subspans(
            token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0]
        evidence_segments = [
            (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments]
        start = 0
        output = []
        if len(evidence_segments) > 0:
            for _segment in evidence_segments:
                if _segment[0] > start:
                    output.append((
                        self.tokenizer.decode(tokens[start: _segment[0]]),
                        None
                    ))
                output.append((
                    self.tokenizer.decode(tokens[_segment[0]: _segment[1]]),
                    "evidence"
                ))
                start = _segment[1]
            length = tokens.shape[-1]
            if _segment[1] < length:
                output.append((
                    self.tokenizer.decode(tokens[_segment[1]: length]),
                    None
                ))
        else:
            output = [(prompt, None)]

        return output, token_evidence.tolist()

    def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3,
                       bins=7):
        tokenized = self.tokenizer(prompt)
        input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
        input_ids = input_ids.expand(bins + 1, -1)
        attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
            self.device)
        attention_mask = attention_mask.expand(bins + 1, -1)
        # One row per grid value, plus a final unsteered row as reference.
        steer_values = torch.zeros(bins + 1, self.num_steers).to(self.device)
        for bin_i in range(bins):
            steer_values[bin_i, steer_dim] = (
                min_value + (max_value - min_value) / (bins - 1) * bin_i
            )
        self.steer.set_value(steer_values)
        with torch.no_grad():
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids)
        length = input_ids.shape[1]
        loss_token = F.cross_entropy(
            output.logits[:, :-1].reshape((bins + 1) * (length - 1), -1),
            input_ids[:, 1:].reshape(-1),
            reduction="none"
        )
        loss_token = loss_token.reshape(bins + 1, length - 1)
        loss = loss_token.mean(-1)[:-1]
        dist = ((- loss + loss.mean()) * 100).softmax(0)
        dist_list = list(zip(
            [
                min_value + (max_value - min_value) / (bins - 1) * bin_i
                for bin_i in range(bins)
            ],
            dist.tolist(),
        ))
        best_guess = loss.argmin(0)
        best_guess_value = min_value + \
            (max_value - min_value) / (bins - 1) * best_guess.item()

        token_evidence = (- loss_token[best_guess] + loss_token[-1]) * 10
        token_evidence = [0] + token_evidence.tolist()

        # Merge tokens into word-level spans at punctuation boundaries and
        # average the per-token evidence within each span.
        word_evidence_list = []
        start = 0
        n_tokens = len(input_ids[0])
        for token_i in range(1, n_tokens + 1):
            span = self.tokenizer.decode(input_ids[0][start: token_i])
            for _punc in punctuations:
                if token_i == n_tokens or _punc in span:
                    new_span = self.tokenizer.decode(
                        input_ids[0][start: token_i-1]).strip()
                    if len(new_span) <= 1:
                        break
                    word_evidence_list.append((
                        new_span,
                        np.array(token_evidence[start: token_i-1]).mean()
                    ))
                    start = token_i - 1
                    break

        return best_guess_value, dist_list, word_evidence_list

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''
        prompt: a string
        steer_values: a list of steering values, one per steer dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        '''
        if seed is not None:
            set_seed(seed)
        steer_values = torch.Tensor(steer_values).to(
            self.device)
        self.steer.set_value(steer_values[None])
        with torch.no_grad():
            text = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        text = text[0]["generated_text"]

        return text
lm_steer/models/model_embedding_tuning_gpt_neo.py ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from transformers import pipeline

from .model_utils import Hack_no_grad
from lm_steer.utils import set_seed


class EmbeddingTuning_GPTNeoModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.generator = pipeline(
            'text-generation',
            model=model_name.replace("embedding_tuning-", ""))
        self.tokenizer = self.generator.tokenizer
        self.model = self.generator.model
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Freeze the transformer; only the lm_head embeddings are tuned.
        self.model.transformer = Hack_no_grad(self.model.transformer)

    def forward(self, input_ids, attention_mask, steer_values):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def parameters(self):
        return [self.model.lm_head.weight]

    def state_dict(self):
        return self.model.lm_head.state_dict()

    def load_state_dict(self, state_dict):
        self.model.lm_head.load_state_dict(state_dict)

    def to_device(self, device):
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            text = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        text = text[0]["generated_text"]
        return text
lm_steer/models/model_gpt_j.py ADDED
@@ -0,0 +1,333 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPTJForCausalLM, AutoTokenizer

from .model_utils import Hack_no_grad, find_max_subspans
from .steers import Projected_Adaptor
from lm_steer.utils import set_seed


punctuations = [
    '!', '"', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.',
    # '/', '#',
    ':', ';', '<', '=', '>', '?', '@',
    '[', '\\', ']', '^', '_', '`',
    '{', '|', '}', '~',
    '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', 'µ', '¶', '·',
    '¸', '¹', 'º', '»', '¼', '½', '¾',
    '\n', ' ',
]


class Switching_GPTJModel(nn.Module):
    def __init__(self, model_name, adapted_component, adaptor_class,
                 num_steers, rank, epsilon, init_var, low_resource_mode):
        super().__init__()
        self.adapted_component = adapted_component
        self.adaptor_class = adaptor_class
        if low_resource_mode:
            print("using low_resource_mode and fp16")
            self.model = GPTJForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B", revision="float16",
                torch_dtype=torch.float16, low_cpu_mem_usage=True
            )
        else:
            self.model = GPTJForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
            )
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.init_var = init_var
        self.num_steers = num_steers
        self.device = torch.device("cpu")
        self.low_resource_mode = low_resource_mode
        embed_dim = self.model.lm_head.weight.shape[1]
        vocab_size = self.model.lm_head.weight.shape[0]

        for _param in self.model.parameters():
            _param.requires_grad_(False)

        if adapted_component == "final_layer":
            self.model.transformer = Hack_no_grad(self.model.transformer)
            self.steer = Projected_Adaptor(
                self.model.lm_head, adaptor_class, num_steers, embed_dim,
                vocab_size, rank, epsilon, init_var, "output")
            self.model.set_output_embeddings(self.steer)
        elif adapted_component == "input_embedding":
            self.steer = Projected_Adaptor(
                self.model.transformer.wte, adaptor_class, num_steers,
                embed_dim, vocab_size, rank, epsilon, init_var, "input")
            self.model.transformer.set_input_embeddings(self.steer)
        else:
            raise NotImplementedError()

    def forward(self, input_ids, attention_mask, steer_values):
        self.steer.set_value(steer_values)
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def parameters(self):
        return self.steer.parameters()

    def state_dict(self):
        return self.steer.state_dict()

    def load_state_dict(self, state_dict):
        self.steer.load_state_dict(state_dict)

    def to_device(self, device):
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        return self.steer.regularization_term()

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''
        prompt: a string
        steer_values: a list of steering values, one per steer dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        '''
        if seed is not None:
            set_seed(seed)
        steer_values = torch.Tensor(steer_values).to(
            self.device)
        if self.low_resource_mode:
            # fp16 inference: cast the steer parameters to match the model.
            fp16 = torch.float16
            steer_values = steer_values.to(fp16)
            self.steer.projector1.data = self.steer.projector1.to(fp16)
            self.steer.projector2.data = self.steer.projector2.to(fp16)
        self.steer.set_value(steer_values[None])
        with torch.no_grad():
            input_ids = self.tokenizer(
                prompt, return_tensors="pt").input_ids.to(self.device)
            gen_tokens = self.model.generate(
                input_ids,
                num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_new_tokens=min_length, max_new_tokens=max_length,
                pad_token_id=self.tokenizer.pad_token_id)
            text = self.tokenizer.batch_decode(gen_tokens)[0]

        # recovering
        if self.low_resource_mode:
            fp32 = torch.float32
            self.steer.projector1.data = self.steer.projector1.to(fp32)
            self.steer.projector2.data = self.steer.projector2.to(fp32)
        return text

    def generate_multiple(
            self, prompts, steer_values, min_length=20, max_length=100,
            seed=None):
        '''
        prompts: a list of strings
        steer_values: one steering vector per prompt
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        '''
        if seed is not None:
            set_seed(seed)
        steer_values = torch.Tensor(steer_values).to(
            self.device)
        if self.low_resource_mode:
            fp16 = torch.float16
            steer_values = steer_values.to(fp16)
            self.steer.projector1.data = self.steer.projector1.to(fp16)
            self.steer.projector2.data = self.steer.projector2.to(fp16)
        self.steer.set_value(steer_values)
        with torch.no_grad():
            input_ids = self.tokenizer(
                prompts, return_tensors="pt",
                padding=True).input_ids.to(self.device)
            gen_tokens = self.model.generate(
                input_ids,
                do_sample=True,
                min_new_tokens=min_length, max_new_tokens=max_length,
                pad_token_id=self.tokenizer.pad_token_id)
            text = self.tokenizer.batch_decode(gen_tokens)

        # recovering
        if self.low_resource_mode:
            fp32 = torch.float32
            self.steer.projector1.data = self.steer.projector1.to(fp32)
            self.steer.projector2.data = self.steer.projector2.to(fp32)
        return text

    def evidence_words(self, prompt, original_steer_values,
                       truncation_length=1024, max_segments=4, max_length=10):
        if isinstance(original_steer_values, list):
            original_steer_values = torch.Tensor(original_steer_values)
        if original_steer_values.abs().sum() <= 0.2:
            return [(prompt, None)]
        tokenized = self.tokenizer(
            prompt, return_tensors="pt",
            max_length=truncation_length, truncation=True)
        input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
        input_ids = input_ids.expand(2, -1)
        attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
            self.device)
        attention_mask = attention_mask.expand(2, -1)
        # Compare the given steer values against their sign-flipped version.
        steer_values = torch.zeros(2, self.num_steers).to(self.device)
        steer_values[0] = original_steer_values
        steer_values[1] = (-original_steer_values > 0) * 2 - 1
        if self.low_resource_mode:
            fp16 = torch.float16
            steer_values = steer_values.to(fp16)
            self.steer.projector1.data = self.steer.projector1.to(fp16)
            self.steer.projector2.data = self.steer.projector2.to(fp16)
        self.steer.set_value(steer_values)
        with torch.no_grad():
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids)
        length = input_ids.shape[1]
        loss_token = F.cross_entropy(
            output.logits[:, :-1].reshape(2 * (length - 1), -1),
            input_ids[:, 1:].reshape(-1),
            reduction="none"
        )
        loss_token = loss_token.reshape(2, length - 1)

        token_evidence = (- loss_token[0] + loss_token[1])
        tokens = input_ids[0]
        evidence_segments = find_max_subspans(
            token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0]
        evidence_segments = [
            (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments]
        start = 0
        output = []
        color = (
            "gray" if original_steer_values.shape[0] > 1
            else "red" if original_steer_values[0] > 0
            else "blue"
        )
        if len(evidence_segments) > 0:
            for _segment in evidence_segments:
                if _segment[0] > start:
                    output.append((
                        self.tokenizer.decode(tokens[start: _segment[0]]),
                        None
                    ))
                output.append((
                    self.tokenizer.decode(tokens[_segment[0]: _segment[1]]),
                    color
                ))
                start = _segment[1]
            length = tokens.shape[-1]
            if _segment[1] < length:
                output.append((
                    self.tokenizer.decode(tokens[_segment[1]: length]),
                    None
                ))
        else:
            output = [(prompt, None)]

        if self.low_resource_mode:
            fp32 = torch.float32
            self.steer.projector1.data = self.steer.projector1.to(fp32)
            self.steer.projector2.data = self.steer.projector2.to(fp32)
        return output

    def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3,
                       bins=7, truncation_length=1024):
        tokenized = self.tokenizer(
            prompt, return_tensors="pt",
            max_length=truncation_length,
            truncation=True)
        input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
        input_ids = input_ids.expand(bins + 1, -1)
        attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
            self.device)
        attention_mask = attention_mask.expand(bins + 1, -1)
        steer_values = torch.zeros(bins + 1, self.num_steers).to(self.device)
        for bin_i in range(bins):
            steer_values[bin_i, steer_dim] = (
                min_value + (max_value - min_value) / (bins - 1) * bin_i
            )
        if self.low_resource_mode:
            fp16 = torch.float16
            steer_values = steer_values.to(fp16)
            self.steer.projector1.data = self.steer.projector1.to(fp16)
            self.steer.projector2.data = self.steer.projector2.to(fp16)
        self.steer.set_value(steer_values)
        with torch.no_grad():
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids)
        length = input_ids.shape[1]
        loss_token = F.cross_entropy(
            output.logits[:, :-1].reshape((bins + 1) * (length - 1), -1),
            input_ids[:, 1:].reshape(-1),
            reduction="none"
        )
        loss_token = loss_token.reshape(bins + 1, length - 1)
        loss = loss_token.mean(-1)[:-1]
        dist = ((- loss + loss.mean()) * 100).softmax(0)
        dist_list = list(zip(
            [
                min_value + (max_value - min_value) / (bins - 1) * bin_i
                for bin_i in range(bins)
            ],
            dist.tolist(),
        ))
        best_guess = loss.argmin(0)
        best_guess_value = min_value + \
            (max_value - min_value) / (bins - 1) * best_guess.item()

        token_evidence = self.evidence_words(
            prompt, steer_values[best_guess],
        )

        if self.low_resource_mode:
            fp32 = torch.float32
            self.steer.projector1.data = self.steer.projector1.to(fp32)
            self.steer.projector2.data = self.steer.projector2.to(fp32)
        return best_guess_value, dist_list, token_evidence
lm_steer/models/model_gpt_neo.py ADDED
@@ -0,0 +1,66 @@
import torch
from transformers import pipeline

from .model_utils import Hack_no_grad
from .steers import Projected_Adaptor
from .model_base import LMSteerBase


class Switching_GPTNeoModel(LMSteerBase):
    def __init__(self, model_name, adapted_component, adaptor_class,
                 num_steers, rank, epsilon, init_var,
                 low_resource_mode):
        super().__init__()
        self.adapted_component = adapted_component
        self.generator = pipeline('text-generation', model=model_name)
        self.tokenizer = self.generator.tokenizer
        self.model = self.generator.model
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.init_var = init_var
        self.num_steers = num_steers
        self.device = torch.device("cpu")
        embed_dim = self.model.lm_head.weight.shape[1]
        vocab_size = self.model.lm_head.weight.shape[0]

        for _param in self.model.parameters():
            _param.requires_grad_(False)

        if adapted_component == "final_layer":
            self.model.transformer = Hack_no_grad(self.model.transformer)
            self.steer = Projected_Adaptor(
                self.model.lm_head, adaptor_class, num_steers, embed_dim,
                vocab_size, rank, epsilon, init_var, "output")
            self.model.set_output_embeddings(self.steer)
        elif adapted_component == "input_embedding":
            self.steer = Projected_Adaptor(
                self.model.transformer.wte, adaptor_class, num_steers,
                embed_dim, vocab_size, rank, epsilon, init_var, "input")
            self.model.transformer.set_input_embeddings(self.steer)
        else:
            raise NotImplementedError()

    def forward(self, input_ids, attention_mask, steer_values):
        self.steer.set_value(steer_values)
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def parameters(self):
        return self.steer.parameters()

    def state_dict(self):
        return self.steer.state_dict()

    def load_state_dict(self, state_dict):
        self.steer.load_state_dict(state_dict)

    def to_device(self, device):
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        return self.steer.regularization_term()
lm_steer/models/model_gpt_neox.py ADDED
@@ -0,0 +1,105 @@
import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer

from .model_utils import Hack_no_grad
from .steers import Projected_Adaptor
from .model_base import LMSteerBase
from lm_steer.utils import set_seed


class Switching_GPTNeoXModel(LMSteerBase):
    def __init__(self, model_name, adapted_component, adaptor_class,
                 num_steers, rank, epsilon, init_var,
                 low_resource_mode):
        super().__init__()
        self.adapted_component = adapted_component
        if low_resource_mode:
            self.model = GPTNeoXForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16, low_cpu_mem_usage=True
            )
        else:
            self.model = GPTNeoXForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.init_var = init_var
        self.num_steers = num_steers
        self.device = torch.device("cpu")
        embed_dim = self.model.embed_out.weight.shape[1]
        vocab_size = self.model.embed_out.weight.shape[0]
        self.low_resource_mode = low_resource_mode

        for _param in self.model.parameters():
            _param.requires_grad_(False)

        if adapted_component == "final_layer":
            self.model.gpt_neox = Hack_no_grad(self.model.gpt_neox)
            self.steer = Projected_Adaptor(
                self.model.embed_out, adaptor_class, num_steers, embed_dim,
                vocab_size, rank, epsilon, init_var, "output")
            self.model.set_output_embeddings(self.steer)
        else:
            raise NotImplementedError()

    def forward(self, input_ids, attention_mask, steer_values):
        self.steer.set_value(steer_values)
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def parameters(self):
        return self.steer.parameters()

    def state_dict(self):
        return self.steer.state_dict()

    def load_state_dict(self, state_dict):
        self.steer.load_state_dict(state_dict)

    def to_device(self, device):
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        return self.steer.regularization_term()

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''
        prompt: a string
        steer_values: a list of steering values, one per steer dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        '''
        if seed is not None:
            set_seed(seed)
        steer_values = torch.Tensor(steer_values).to(
            self.device)
        if self.low_resource_mode:
            # fp16 inference: cast the steer parameters to match the model.
            fp16 = torch.float16
            steer_values = steer_values.to(fp16)
            self.steer.projector1.data = self.steer.projector1.to(fp16)
            self.steer.projector2.data = self.steer.projector2.to(fp16)
        self.steer.set_value(steer_values[None])
        with torch.no_grad():
            input_ids = self.tokenizer(
                prompt, return_tensors="pt").input_ids.to(self.device)
            gen_tokens = self.model.generate(
                input_ids,
                num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id)
            text = self.tokenizer.batch_decode(gen_tokens)[0]

        # recovering
        if self.low_resource_mode:
            fp32 = torch.float32
            self.steer.projector1.data = self.steer.projector1.to(fp32)
            self.steer.projector2.data = self.steer.projector2.to(fp32)
        return text
lm_steer/models/model_lora_gpt_neo.py ADDED
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from transformers import pipeline
from peft import LoraConfig, get_peft_model

from lm_steer.utils import set_seed


class LORA_GPTNeoModel(nn.Module):
    def __init__(self, model_name, rank, epsilon):
        super().__init__()
        self.generator = pipeline('text-generation',
                                  model=model_name.replace("lora-", ""))
        self.tokenizer = self.generator.tokenizer
        model = self.generator.model
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # The LoRA rank and scaling reuse the steer hyperparameters.
        config = LoraConfig(
            r=rank,
            lora_alpha=epsilon,
            target_modules=["c_attn", "c_proj", "c_fc"],
            lora_dropout=0.1,
            bias="lora_only",
            modules_to_save=[],
        )
        self.model = get_peft_model(model, config)
        self.generator.model = self.model
        self.model.print_trainable_parameters()

    def forward(self, input_ids, attention_mask, steer_values):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def to_device(self, device):
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            text = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        text = text[0]["generated_text"]
        return text
lm_steer/models/model_utils.py ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import numpy as np


class Hack_no_grad(nn.Module):
    '''Wraps a module so its forward pass always runs under torch.no_grad().'''

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        with torch.no_grad():
            return self.module(*inputs, **kwargs)


def find_max_subspans(sequence, n_spans, max_length):
    # Dynamic program that selects up to n_spans disjoint spans, each at
    # most max_length long, maximizing the sum of the covered values.
    # State: (position, spans still available, remaining length budget,
    # whether the current position is inside a span).
    length = len(sequence)
    inner_scores = np.zeros((length, n_spans + 1, max_length + 1, 2))
    trace = np.zeros((length, n_spans + 1, max_length + 1, 2, 3), dtype=int)
    # Sentinel: row -1 (reached by the _i-1 indexing below) forbids being
    # "inside a span" before the sequence begins.
    inner_scores[-1, :, :, 1] = -1e5
    for _i in range(length):
        for _j in range(n_spans+1):
            for _k in range(max_length+1):
                trace[_i, _j, _k, 0] = (_j, max_length, 0)

    for _i in range(length):
        for _j in range(n_spans):
            for _k in range(max_length+1):
                # Stay outside a span...
                inner_scores[_i, _j, _k, 0], trace[_i, _j, _k, 0] = (
                    inner_scores[_i-1, _j, max_length, 0],
                    (_j, max_length, 0)
                )
                # ...or close a span that ended at the previous position.
                max_taken = inner_scores[_i-1, _j, :, 1].max()
                if max_taken > inner_scores[_i, _j, _k, 0]:
                    inner_scores[_i, _j, _k, 0] = max_taken
                    trace[_i, _j, _k, 0] = (
                        _j, inner_scores[_i-1, _j, :, 1].argmax(), 1)

                if _k < max_length:
                    # Extend the current span, or open a new one here.
                    inner_scores[_i, _j, _k, 1], trace[_i, _j, _k, 1] = (
                        (
                            inner_scores[_i-1, _j, _k+1, 1] + sequence[_i],
                            (_j, _k+1, 1)
                        )
                        if (inner_scores[_i-1, _j, _k+1, 1] >
                            inner_scores[_i-1, _j+1, max_length, 0])
                        else (
                            inner_scores[_i-1, _j+1, max_length, 0] +
                            sequence[_i],
                            (_j+1, max_length, 0)
                        )
                    )

    max_score = 0
    argmax = (0, 0, 0)
    for _j in reversed(range(n_spans + 1)):
        for _k in reversed(range(max_length)):
            if inner_scores[-1, _j, _k, 0] > max_score:
                max_score = inner_scores[-1, _j, _k, 0]
                argmax = (_j, _k, 0)
            if inner_scores[-1, _j, _k, 1] > max_score:
                max_score = inner_scores[-1, _j, _k, 1]
                argmax = (_j, _k, 1)

    # Trace back the chosen states to tag each position as in/out of a span.
    trace_back = argmax
    tags = []
    for _i in reversed(range(length)):
        tags.append(trace_back[2])
        trace_back = trace[_i, trace_back[0], trace_back[1], trace_back[2]]

    tags.reverse()
    segments = []
    start = None
    for _i in range(length + 1):
        if _i < length and tags[_i] == 1 and start is None:
            start = _i
        elif (_i == length or tags[_i] == 0) and start is not None:
            segments.append((start, _i))
            start = None
    return segments, max_score, tags
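
An illustrative call with arbitrary scores (find_max_subspans picks up to n_spans disjoint spans, each at most max_length long, that maximize the summed score):

from lm_steer.models.model_utils import find_max_subspans

scores = [0.1, 2.0, 1.5, -3.0, 0.2, 2.5, -0.5]
segments, max_score, tags = find_max_subspans(scores, n_spans=2, max_length=3)
# segments: (start, end) index pairs of the selected spans
# tags: one 0/1 flag per position marking span membership
print(segments, max_score)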
lm_steer/models/steers.py ADDED
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn


class Projected_Adaptor(nn.Module):
    def __init__(self, lm_head, adaptor_class, num_steers, embed_dim,
                 vocab_size, rank, epsilon, init_var, position="output"):
        super().__init__()
        assert rank > 0
        if adaptor_class == "multiply":
            # Low-rank factors, one pair per steer: (num_steers, embed_dim, rank).
            self.projector1 = nn.Parameter(torch.randn(
                num_steers, embed_dim, rank
            ) * init_var)
            self.projector2 = nn.Parameter(torch.randn(
                num_steers, embed_dim, rank
            ) * init_var)
        elif adaptor_class == "add":
            self.add_vec = nn.Parameter(torch.randn(
                num_steers, embed_dim
            ))
        elif adaptor_class == "offset":
            self.offset_vec = nn.Parameter(torch.randn(
                num_steers, vocab_size
            ))
        else:
            raise NotImplementedError()

        self.adaptor_class = adaptor_class
        self.rank = rank
        self.lm_head = lm_head
        self.epsilon = epsilon
        self.position = position
        self.num_steers = num_steers
        self.init_var = init_var
        self.steer_values = torch.zeros(num_steers)

    def set_value(self, steer_values):
        self.steer_values = steer_values

    def forward(self, state):
        if self.steer_values.abs().sum() == 0:
            # All steers off: plain (frozen) output projection.
            return state.matmul(
                self.lm_head.weight.detach().transpose(0, 1))
        if self.adaptor_class == "multiply":
            # state: (batch, seq, embed); steer_values: (batch, num_steers).
            # delta: (batch, num_steers, seq, rank), projected back through
            # projector2 and summed over steers to (batch, seq, embed).
            delta = state[:, None].matmul(self.projector1[None]) *\
                self.steer_values[:, :, None, None]
            delta = delta.matmul(
                self.projector2.transpose(1, 2)[None]).sum(1)
            projected_state = state + self.epsilon * delta
            logits = projected_state.matmul(
                self.lm_head.weight.detach().transpose(0, 1))
        elif self.adaptor_class == "add":
            add_values = self.steer_values.matmul(self.add_vec)
            projected_state = state + self.epsilon * add_values[:, None]
            logits = projected_state.matmul(
                self.lm_head.weight.detach().transpose(0, 1))
        elif self.adaptor_class == "offset":
            offset_values = self.steer_values.matmul(self.offset_vec)
            logits = state.matmul(
                self.lm_head.weight.detach().transpose(0, 1))
            logits = logits + self.epsilon * offset_values[:, None]
        return logits

    def regularization_term(self):
        if self.adaptor_class == "multiply":
            return self.projector1.pow(2).sum() + self.projector2.pow(2).sum()
        elif self.adaptor_class == "add":
            return self.add_vec.pow(2).sum()
        elif self.adaptor_class == "offset":
            return self.offset_vec.pow(2).sum()

    def parameters(self):
        if self.adaptor_class == "multiply":
            return [self.projector1, self.projector2]
        elif self.adaptor_class == "add":
            return [self.add_vec]
        elif self.adaptor_class == "offset":
            return [self.offset_vec]

    def state_dict(self):
        if self.adaptor_class == "multiply":
            return {"projector1": self.projector1,
                    "projector2": self.projector2}
        elif self.adaptor_class == "add":
            return {"add_vec": self.add_vec}
        elif self.adaptor_class == "offset":
            return {"offset_vec": self.offset_vec}

    def load_state_dict(self, state_dict):
        if self.adaptor_class == "multiply":
            self.projector1.data = state_dict["projector1"]
            self.projector2.data = state_dict["projector2"]
        elif self.adaptor_class == "add":
            self.add_vec.data = state_dict["add_vec"]
        elif self.adaptor_class == "offset":
            self.offset_vec.data = state_dict["offset_vec"]
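
For intuition, a self-contained numerical sketch of the "multiply" branch above (shapes and names follow Projected_Adaptor; all values are random):

import torch

batch, seq, embed, vocab, rank, num_steers = 1, 5, 8, 11, 3, 2
state = torch.randn(batch, seq, embed)           # final hidden states
W = torch.randn(vocab, embed)                    # output embedding matrix
P1 = torch.randn(num_steers, embed, rank) * 1e-2
P2 = torch.randn(num_steers, embed, rank) * 1e-2
z = torch.tensor([[3.0, 0.0]])                   # (batch, num_steers)

delta = (state[:, None].matmul(P1[None]) * z[:, :, None, None]).matmul(
    P2.transpose(1, 2)[None]).sum(1)
logits = (state + 1e-3 * delta).matmul(W.t())    # epsilon = 1e-3
print(logits.shape)                              # torch.Size([1, 5, 11])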
lm_steer/utils.py ADDED
@@ -0,0 +1,45 @@
import random
import torch
import numpy as np


def set_seed(seed):
    if seed is None:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


class RunningMean:
    '''An exponential moving average with bias correction for early steps.'''

    def __init__(self, gamma):
        self.gamma = gamma
        self.count = 0
        self._value = None

    def update(self, value):
        value = value.detach().cpu()
        if value.ndim == 0:
            self._update(value)
        else:
            for _v in value:
                self._update(_v)

    def _update(self, value):
        self.count += 1
        if self._value is None:
            self._value = value
        else:
            # Weights renormalized so early estimates are not biased toward
            # the initial value (cf. Adam's bias correction).
            w1 = self.gamma * (1 - self.gamma ** (self.count - 1))
            w2 = (1 - self.gamma)
            wt = w1 + w2
            w1 = w1 / wt
            w2 = w2 / wt
            self._value = w1 * self._value + w2 * value

    @property
    def value(self):
        if self._value is None:
            return 0
        return self._value * 1  # multiply by 1 to return a copy
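
A quick sanity sketch of the bias correction: the first value is kept as-is rather than shrunk toward an initial zero, and the second update lands near the plain average.

import torch
from lm_steer.utils import RunningMean

rm = RunningMean(gamma=0.99)
rm.update(torch.tensor(2.0))
print(float(rm.value))   # 2.0
rm.update(torch.tensor(4.0))
print(float(rm.value))   # ~3.0, close to the unweighted mean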
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch
transformers
datasets
numpy
pandas