File size: 42,110 Bytes
f7acd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f655011
4da1fb0
2c5c709
 
f655011
2c5c709
 
286d976
fb1c1ed
f655011
129904f
 
95b2105
f655011
 
2b78064
 
f655011
2b78064
5a498e2
95b2105
2b78064
f655011
 
 
2b78064
f655011
 
 
 
85e58bb
 
 
 
c50cfb4
85e58bb
2b78064
 
f7acd50
 
f655011
fb1c1ed
f655011
9acb8e6
2c5c709
f655011
d646867
f655011
 
 
2b78064
f655011
 
f7a1bd4
f655011
 
2b78064
 
f655011
 
 
 
 
2b78064
f7a1bd4
 
 
 
2b78064
f7a1bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b78064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c5c709
f7a1bd4
 
 
 
 
85e58bb
 
f7acd50
d646867
2c5c709
f7acd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c5c709
f7acd50
 
 
129904f
f7acd50
 
 
 
 
 
 
 
2c5c709
 
 
 
 
 
 
 
3379490
f7a1bd4
 
 
f7acd50
f7a1bd4
 
 
129904f
2b78064
129904f
f655011
 
129904f
9acb8e6
 
3379490
9acb8e6
453c7fc
9acb8e6
 
85e58bb
 
2fd6955
85e58bb
f7a1bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85e58bb
 
 
f7a1bd4
 
 
 
 
 
 
 
453c7fc
f7a1bd4
 
453c7fc
2b78064
 
 
f7a1bd4
 
 
 
453c7fc
f655011
f7a1bd4
d646867
 
 
 
bd9fdbb
 
f655011
2c5c709
d646867
85e58bb
 
d646867
 
85e58bb
2c5c709
85e58bb
d646867
f655011
 
3379490
f7a1bd4
 
 
 
 
f7acd50
 
 
f7a1bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
f655011
3379490
 
 
 
 
2fd6955
3379490
 
f7a1bd4
 
 
 
f7acd50
3379490
 
f7a1bd4
3379490
 
 
 
 
f7a1bd4
3379490
 
 
 
 
 
 
f7a1bd4
3379490
 
f7a1bd4
 
 
 
f7acd50
3379490
 
f7a1bd4
3379490
 
f7a1bd4
3379490
 
f7acd50
3379490
 
 
 
 
f7a1bd4
3379490
 
 
 
 
 
 
 
 
f7a1bd4
 
 
 
 
3379490
f7acd50
 
3379490
 
f7a1bd4
3379490
f7acd50
3379490
 
 
f7a1bd4
3379490
 
 
 
 
 
 
 
 
f7a1bd4
 
 
 
f7acd50
3379490
f7acd50
f7a1bd4
3379490
 
f7a1bd4
3379490
 
f7acd50
3379490
 
 
 
 
f7a1bd4
3379490
 
85e58bb
 
 
 
 
 
 
 
 
2c5c709
 
 
 
 
 
 
2b78064
 
 
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb1c1ed
 
2c5c709
2b78064
2c5c709
 
 
 
 
fb1c1ed
2c5c709
 
 
 
 
 
 
 
 
 
 
 
f7acd50
2c5c709
 
 
fb1c1ed
2c5c709
 
 
 
 
 
 
 
 
3379490
 
 
 
 
 
 
129904f
3379490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129904f
3379490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7a1bd4
2c5c709
 
 
 
 
f7acd50
2c5c709
f7acd50
2c5c709
434becc
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129904f
f655011
2c5c709
 
 
f655011
2c5c709
 
f655011
2c5c709
 
 
 
 
9acb8e6
 
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3379490
2c5c709
 
3379490
2c5c709
129904f
2c5c709
 
 
 
 
 
3379490
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129904f
2c5c709
 
 
 
 
 
 
 
 
434becc
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7acd50
 
2c5c709
 
 
 
 
 
 
 
 
d646867
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
d646867
2c5c709
 
 
 
 
 
 
 
f7acd50
 
2c5c709
f7acd50
 
2c5c709
f7acd50
 
2c5c709
f7acd50
 
 
 
 
2c5c709
f7acd50
 
2c5c709
f7acd50
 
2c5c709
f7acd50
 
 
 
 
 
 
 
2c5c709
f7acd50
 
2c5c709
f7acd50
 
 
 
 
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7acd50
2c5c709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7acd50
 
f655011
2c5c709
 
f7acd50
 
2c5c709
 
fb1c1ed
2c5c709
 
fb1c1ed
2c5c709
 
 
f7acd50
 
3379490
2c5c709
 
 
f7acd50
3379490
2c5c709
 
 
 
 
3379490
2c5c709
 
 
 
 
3379490
f7acd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f7f1fd
f655011
f7acd50
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
"""
    Controlled Chat is a graphical and chat interface to Representation Engineering.
    It creates a single Gradio application to be run locally or on a Hugging Face space.
    This version is intended to run on CPU, and so uses Llama 3.2 1B.
    It is hosted online at https://huggingface.co/spaces/Abrak/Controlled_Chat_CPU/.

    There is also a GPU version based on Mistral 0.3 9B, requiring 16GB of VRAM.
    Find it at https://huggingface.co/spaces/Abrak/Controlled_Chat.

    You can also run this application locally: create a venv, install the requirements, and run this script.

    If you want to port this to another model, you'll need to do a few things:
    1. Change the model path (`llama_path`) near the top of the code
    2. Experiment with different ranges of layers in the call to ControlModel()
    3. Change out the construct_prompt_* function to fit the model's prompt syntax
    4. Call train_models()

    If you clone this project, you can add new models into the control_models directory and everything should work.

    This file's code is licensed under MIT. See the README.MD and LLAMA LICENSE.TXT.
"""

import os
import threading
import json
import csv
import torch
import re
import tempfile
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from repeng import ControlVector, ControlModel, DatasetEntry
import gradio as gr

# Hugging Face login
from huggingface_hub import login

# Initialize model and tokenizer
llama_path = "meta-llama/Llama-3.2-1B-Instruct"
#llama_path = r"E:/language_models/models/mistral"

access_token = os.getenv("llamaaccesstoken")
if access_token:
    login(access_token)
else:
    # login(None) falls back to an interactive prompt, which would block a headless
    # server; skip it and rely on any cached credentials instead.
    print("No 'llamaaccesstoken' environment variable set; skipping Hugging Face login.")

tokenizer = AutoTokenizer.from_pretrained(llama_path)
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(
    llama_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=True
)
cuda = torch.cuda.is_available()
print(f"Is CUDA available: {cuda}")
model = model.to("cuda:0" if cuda else "cpu")
if cuda:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Choose which decoder layers the control vectors are applied to.
# Mistral has 32 layers addressed with negative indices; 13 layers, -5 down to -17, were used:
# model = ControlModel(model, list(range(-5, -18, -1)))
# Llama 3.2 1B layers are indexed 0 to 15. Experimentally, layers 10 down to 6 work best
# (range(10, 5, -1) == [10, 9, 8, 7, 6]).
model = ControlModel(model, list(range(10, 5, -1)))

# Default generation settings: greedy decoding with a mild repetition penalty
default_generation_settings = dict(
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,          # Deterministic output
    max_new_tokens=384,
    repetition_penalty=1.1,   # Reduce repetition
)



# List available control vectors.
# A missing directory is treated like an empty one, so the app can still start and be
# used with an uploaded control vector.
if os.path.isdir('control_models'):
    control_vector_files = [f for f in os.listdir('control_models') if f.endswith('.gguf')]
else:
    control_vector_files = []

if not control_vector_files:
    # Deliberately non-fatal (see above); surface the empty list in the logs instead.
    print("Warning: no .gguf control vector files found in the control_models directory.")

# Shows or hides a weight slider to match its checkbox
def toggle_slider(checked):
    """Return a Gradio update making the slider visible exactly when `checked` is truthy."""
    visibility_update = gr.update(visible=checked)
    return visibility_update

def construct_prompt_mistral(history, system_prompt, user_message):
    """
    Converts the history (list of (user, assistant) tuples) back into the string format Mistral expects:
        <s>[INST] user message [/INST] assistant message</s>[INST] new user message [/INST]

    The first line of each stored assistant message (the "*control vector*" title that
    generate_response prepends for display) is dropped so the model only sees its reply.

    Returns the formatted prompt string.
    """
    user_tag, asst_tag = "[INST]", "[/INST]"
    parts = []

    # Mistral expects the history to be wrapped in <s>history</s>
    if history:
        parts.append("<s>")

    # The system prompt is sent as an initial instruction turn
    if system_prompt.strip():
        parts.append(f"{user_tag} {system_prompt}{asst_tag} ")

    for user_msg, asst_msg in history:
        # Drop the title line, then rejoin the remaining lines into a single string
        # (interpolating the list from split() directly would embed "['', ...]").
        reply = "\n".join(asst_msg.split("\n")[1:]).strip()
        parts.append(f"{user_tag} {user_msg} {asst_tag} {reply}")

    if history:
        parts.append("</s>")

    # Append the new user message
    parts.append(f"{user_tag} {user_message} {asst_tag}")
    return "".join(parts)

# Matches the "*<control vector summary>*\n\n" header that generate_response prepends
# to every assistant message for display; it should not be fed back to the model.
_TITLE_HEADER_RE = re.compile(r"\A\*[^\n]*\*\n\n")


def construct_prompt_llama(history, system_prompt, user_message):
    """
    Converts the history (list of (user, assistant) tuples) into the prompt string Llama 3 expects.

    Llama 3 prompt format: every header is followed by a blank line, and each turn ends
    with <|eot_id|> directly followed by the next header:
        <|begin_of_text|><|start_header_id|>system<|end_header_id|>

        You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

        What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

    Args:
        history: previous (user message, assistant message) pairs. A leading display
            title line ("*...*" followed by a blank line) is stripped from assistant messages.
        system_prompt: optional system prompt; skipped when blank.
        user_message: the new user message.

    Returns:
        The formatted prompt, ending with an open assistant header for generation.
    """
    def header(role):
        return f"<|start_header_id|>{role}<|end_header_id|>\n\n"

    parts = ["<|begin_of_text|>"]

    # Append the system prompt if provided
    if system_prompt.strip():
        parts.append(header("system") + system_prompt.strip() + "<|eot_id|>")

    # Replay the conversation so far
    for user_msg, asst_msg in history:
        parts.append(header("user") + user_msg.strip() + "<|eot_id|>")
        reply = _TITLE_HEADER_RE.sub("", asst_msg, count=1)
        parts.append(header("assistant") + reply.strip() + "<|eot_id|>")

    # Append the new user message, then open the assistant turn for generation
    parts.append(header("user") + user_message.strip() + "<|eot_id|>")
    parts.append(header("assistant"))

    return "".join(parts)


def _combine_control_vectors(args, user_model, input_checkbox, input_slider):
    """
    Sum the control vectors selected in the UI into a single vector.

    Args:
        args: len(control_vector_files) checkbox values followed by the same number of
            slider values, both in control_vector_files order. Empty when called without
            the UI inputs (e.g. from test_generate).
        user_model: the uploaded .gguf control vector; used only when input_checkbox is set.
        input_checkbox: whether the uploaded vector should be applied.
        input_slider: weight for the uploaded vector.

    Returns:
        (combined_vector, title): combined_vector is None when nothing is selected; title
        is a "name: weight;" summary displayed above the assistant's reply.
    """
    combined_vector = None
    title = ""

    if args:
        count = len(control_vector_files)
        checkboxes = args[:count]
        sliders = args[count:2 * count]
        for cv_file, checked, weight in zip(control_vector_files, checkboxes, sliders):
            if not checked:
                continue
            # Scale the vector (and possibly flip its sign) by its slider value
            vector = ControlVector.import_gguf(f"control_models/{cv_file}") * weight
            if combined_vector is None:
                combined_vector = vector
            else:
                combined_vector += vector
            title += f"{cv_file.split('.')[0]}: {weight};"

    if input_checkbox:
        if user_model:
            # User has uploaded their own gguf control vector
            input_vector = ControlVector.import_gguf(user_model) * input_slider
            if combined_vector is None:
                combined_vector = input_vector
            else:
                combined_vector += input_vector
            title += f"Uploaded: {input_slider};"
        else:
            print("Uploaded control vector is enabled but no file was provided; ignoring it.")

    return combined_vector, title


def generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, user_model, input_checkbox, input_slider, *args):
    """
    Applies the control vectors and streams the language model's reply.

    Yields the chatbot history (list of (user, assistant) tuples) repeatedly as tokens
    arrive, which Gradio uses to update the chat live. Each assistant message is prefixed
    with an italic "*name: weight;*" line describing the applied control vectors. The final
    exchange is also appended to `history` in place.

    Args:
        repitition_penalty: the repetition penalty value (parameter name kept for callers).
        *args: control-vector checkbox values followed by slider values
            (see _combine_control_vectors).
    """
    global previous_turn
    # Remembered so generate_response_with_retry can regenerate this turn
    previous_turn = user_message

    combined_vector, assistant_message_title = _combine_control_vectors(
        args, user_model, input_checkbox, input_slider)

    # Always clear the previous control first, so unticking every vector really returns
    # the model to its uncontrolled behaviour.
    try:
        model.reset()
        if combined_vector is not None:
            model.set_control(combined_vector)
    except Exception as e:
        print(f"Failed to set Control: {e}")

    formatted_prompt = construct_prompt_llama(history, system_prompt, user_message)

    # The prompt may already start with the BOS token; don't let the tokenizer add a second one
    bos = tokenizer.bos_token
    add_bos = not (bos and formatted_prompt.startswith(bos))
    input_ids = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=add_bos).to(model.device)

    # Use the value Gradio passed in; if handed a Gradio component instead, read its .value
    penalty = float(getattr(repitition_penalty, "value", repitition_penalty))

    # CPU generation is much slower, so wait longer for each token
    timeout = 15.0 if cuda else 120.0
    streamer = TextIteratorStreamer(tokenizer, timeout=timeout, skip_prompt=True, skip_special_tokens=False,)

    generate_kwargs = dict(
        input_ids,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=do_sample,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=penalty,
    )
    # Generate on a worker thread while this generator consumes tokens from the streamer
    threading.Thread(target=model.generate, kwargs=generate_kwargs).start()

    header = f"*{assistant_message_title}*\n\n"
    partial_message = ""
    # Show the control vector info while we wait for the first token
    yield history + [(user_message, header + "*Please wait*...")]
    for new_token in streamer:
        if new_token != '<' and new_token != '</s>':  # seems to hit EOS correctly without this needed
            partial_message += new_token
            yield history + [(user_message, header + partial_message)]
        else:
            streamer.end()

    # Remove the trailing end-of-turn token; it is absent if generation hit max_new_tokens
    assistant_response = partial_message.removesuffix("<|eot_id|>")

    # Update conversation history
    history.append((user_message, header + assistant_response))
    yield history

def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, user_model, input_checkbox, input_slider, *args):
    """
    Regenerate the most recent assistant reply.

    Drops the last (user, assistant) exchange from history and calls generate_response()
    again with the previous turn's user text (kept in the `previous_turn` global), even
    though that text is no longer in the textbox. Yields [history, previous user message].
    """
    global previous_turn
    try:
        previous_user_message = previous_turn
    except NameError:
        # No message has been sent yet, so there is nothing to retry; leave the UI unchanged.
        yield [history, user_message]
        return

    if history:
        history = history[:-1]

    # Forward this call's own repetition penalty value, not the module-level Gradio component
    for output in generate_response(system_prompt, previous_user_message, history, max_new_tokens, repitition_penalty, do_sample, user_model, input_checkbox, input_slider, *args):
        yield [output, previous_user_message]

# Clears the conversation and the input box
def reset_chat():
    """Return a blank chat history together with an empty textbox value."""
    empty_history, empty_text = [], ""
    return empty_history, empty_text

def get_checkboxes():
    """
    Map each control-model checkbox label to its position among the checkboxes.

    The mapping is rebuilt from the Gradio layout so the presets don't need to change
    when a new control model is added.
    Warning: adding any new components into the header before the checkboxes breaks this
    hard-coded path into app's layout tree.
    """
    column_children = app.children[0].children[0].children[2].children[0].children
    label_to_index = {}
    next_index = 0
    for component in column_children:
        if not isinstance(component, gr.Row):
            continue
        try:
            label = component.children[0].children[0].label
        except (IndexError, AttributeError):
            # Other rows without the checkbox nesting may live in this column
            continue
        label_to_index[label] = next_index
        next_index += 1
    return label_to_index

def _preset_values(weights):
    """
    Build the checkbox and slider values for a persona preset.

    Args:
        weights: maps a control-model label to its slider value. Listed labels are
            ticked; every other label is unticked with weight 0.0.

    Returns:
        All checkbox values followed by all slider values, in the checkbox order
        reported by get_checkboxes().
    """
    labels = get_checkboxes()
    checkbox_values = [label in weights for label in labels]
    slider_values = [weights.get(label, 0.0) for label in labels]
    return checkbox_values + slider_values


def set_preset_helpful(*args):
    """Preset: empathetic and optimistic. args (current checkbox/slider values) are ignored."""
    return _preset_values({"Empathetic": 1.0, "Optimistic": 1.0})


def set_preset_conspiracist(*args):
    """Preset: strongly conspiracist and creative, slightly un-lazy, dishonest. args are ignored."""
    return _preset_values({
        "Conspiracist": 1.5,
        "Creative": 1.0,
        "Lazy": -0.5,
        "Honest": -1.0,
    })


def set_preset_stoner(*args):
    """Preset: mildly angry, anti-conservative, tripping. args are ignored."""
    return _preset_values({
        "Angry": 0.3,
        "Conservative": -0.5,
        "Tripping": 1.0,
    })


def set_preset_facts(*args):
    """Preset: calm, serious, diligent and honest. args are ignored."""
    return _preset_values({
        "Worried": -0.5,
        "Joking": -0.5,
        "Lazy": -0.5,
        "Honest": 0.5,
    })

def disable_controls():
    """Lock two wired components while generating; the first also shows a busy label."""
    busy_button = gr.update(interactive=False, value="⌛ Processing")
    locked_component = gr.update(interactive=False)
    return busy_button, locked_component

def enable_controls():
    """Unlock both wired components; the first gets its "Submit" label back."""
    ready_button = gr.update(interactive=True, value="💬 Submit")
    unlocked_component = gr.update(interactive=True)
    return ready_button, unlocked_component

def clear_input(input_textbox):
    """Return an empty string to clear the textbox; its current contents are ignored."""
    cleared_text = ""
    return cleared_text

def make_dataset(
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str]
) -> list[DatasetEntry]:
    """
    Build contrastive repeng training entries.

    For every suffix, and for every (positive, negative) persona pair taken in lockstep,
    `template` is filled with each persona and wrapped in Llama chat tags, followed by the
    same suffix. Entries are ordered suffix-major.
    """
    # Tags for prompt formatting with Llama
    user_tag = "<|start_header_id|>user<|end_header_id|>\n\n"
    asst_tag = "<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
    return [
        DatasetEntry(
            positive=f"{user_tag} {template.format(persona=pos)} {asst_tag} {suffix}",
            negative=f"{user_tag} {template.format(persona=neg)} {asst_tag} {suffix}",
        )
        for suffix in suffix_list
        for pos, neg in zip(positive_personas, negative_personas)
    ]

def train_model_persona(positive_text, negative_text):
    """
    Train a control vector from paired persona descriptions and save it as a .gguf file.

    Args:
        positive_text, negative_text: newline-separated persona adjectives; line i of each
            forms one contrastive pair. Blank lines are ignored.

    Returns:
        Path of a temporary .gguf file (not auto-deleted) for Gradio to offer as a download.
    """
    # Blank lines (e.g. a trailing newline in the textbox) would otherwise become empty personas
    positive_list = [line.strip() for line in positive_text.splitlines() if line.strip()]
    negative_list = [line.strip() for line in negative_text.splitlines() if line.strip()]

    with open("all_truncated_outputs.json", encoding="utf-8") as f:
        output_suffixes = json.load(f)
    dataset = make_dataset(
        "Act as if you are an extremely {persona} person",
        positive_list,
        negative_list,
        output_suffixes)

    # Train against the bare model, not one steered by a previously applied vector
    model.reset()
    output_model = ControlVector.train(model, tokenizer, dataset)

    # Name the download after the first positive persona, minus filename-unsafe characters
    first_persona = positive_list[0] if positive_list else ""
    filename = re.sub(r'[ <>:"/\\|?*]', '', first_persona) + '_'
    # Create and close the temp file before exporting: writing to the path of a
    # still-open NamedTemporaryFile fails on Windows.
    with tempfile.NamedTemporaryFile(prefix=filename, suffix=".gguf", delete=False) as temp_file:
        output_path = temp_file.name
    ControlVector.export_gguf(output_model, output_path)
    return output_path

def train_model_facts(positive_text, negative_text):
    """Train a "facts"-style control vector and export it as a GGUF file.

    ``positive_text`` and ``negative_text`` are each ONE persona (e.g.
    "time traveler from the future"). Each is inserted into "Pretend to be a
    {persona} making statements about the world." and combined with every
    fact read from ``true_facts.csv``.

    Returns:
        Path of the exported ``.gguf`` file. It is created with
        ``delete=False``, so it persists for Gradio to offer as a download.

    Raises:
        gr.Error: if either persona is blank.
    """
    positive_text = (positive_text or "").strip()
    negative_text = (negative_text or "").strip()
    if not positive_text or not negative_text:
        raise gr.Error("Enter both a positive and a negative persona.")

    with open("true_facts.csv", newline="", encoding="utf-8") as f:
        # csv.reader yields each row as a list of fields. Re-join them so every
        # suffix is plain text, not the repr of a list.
        # NOTE(review): assumes one fact per row; confirm the CSV has no header
        # or extra columns.
        fact_suffixes = [",".join(row) for row in csv.reader(f) if row]

    # Wrap each persona in a single-element list. Passing the raw strings would
    # make make_dataset's zip pair them up character by character.
    dataset = make_dataset(
        "Pretend to be a {persona} making statements about the world.",
        [positive_text],
        [negative_text],
        fact_suffixes
    )

    # Reset before training, matching train_model_persona.
    model.reset()
    output_model = ControlVector.train(model, tokenizer, dataset)

    # Prefix the download name with the persona, minus characters that are not
    # allowed in file names.
    filename = re.sub(r'[ <>:"/\\|?*]', '', positive_text) + '_'
    # Close the temp file before exporting. An open NamedTemporaryFile cannot
    # be reopened by name on Windows.
    with tempfile.NamedTemporaryFile(
        prefix=filename,
        suffix=".gguf",
        delete=False,
    ) as temp_file:
        output_path = temp_file.name
    ControlVector.export_gguf(output_model, output_path)
    return output_path

# Extra CSS passed to gr.Blocks. It styles hover tooltips for HTML shaped like
# <div class="tooltip"><span>label</span><span class="tooltiptext">help</span></div>.
tooltip_css = """
/* Tooltip container */
    .tooltip {
        position: relative;
        display: inline-block;
        cursor: help;
    }

    /* Tooltip text */
    .tooltip .tooltiptext {
        visibility: hidden;
        width: 200px;
        background-color: #1f2937;
        color: #f3f4f6;
        text-align: left;
        border-radius: 6px;
        padding: 8px;
        position: absolute;
        z-index: 1;
        bottom: 125%; /* Position above the element */
        left: 50%;
        margin-left: -100px;
        opacity: 0;
        transition: opacity 0.3s;
    }

    /* Tooltip arrow */
    .tooltip .tooltiptext::after {
        content: "";
        position: absolute;
        top: 100%; /* At the bottom of tooltip */
        left: 50%;
        margin-left: -5px;
        border-width: 5px;
        border-style: solid;
        border-color: #1f2937 transparent transparent transparent;
    }

    /* Show the tooltip text when hovering */
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
    }
"""


# Theme loaded from the Hugging Face Hub; used as the app theme in gr.Blocks below.
# NOTE(review): .set() is called with no overrides. It may still fill in theme
# variables the Hub theme lacks, so confirm before removing it. The commented
# lines show how to add a background image (replace "image uri").
dark_theme = gr.Theme.from_hub("ParityError/Anime").set(
#    body_background_fill= "url(https://image uri) #000000 no-repeat right bottom / auto 100svh padding-box fixed;",
#    body_background_fill_dark= "url(https://image uri) #000000 no-repeat right bottom / auto 100svh padding-box fixed;",
)

# Two-tab UI: "Use" chats with the model while applying selected control
# vectors; "Train" builds a new control vector and offers it as a GGUF download.
with gr.Blocks(
    theme=dark_theme,
    css=tooltip_css,
    ) as app:
    
    with gr.Tab(
        label="Use"
    ):
        # Header (the CPU variant adds a speed warning)
        if cuda:
            gr.Markdown("# 🧠 LLM Mind Control (Llama 3.2 1B)")
        else:
            gr.Markdown("""# 🧠 LLM Mind Control (Llama 3.2 1B)

    *Warning: although using a small model, running on CPU will still be very slow (30+ seconds to first token)*""")
        gr.Markdown("""Unlike prompting, direct weight manipulation lets you fine-tune the amount of a personality
    trait or topic. Enabled through [Representation Engineering](https://arxiv.org/abs/2310.01405)
    via the [repeng](https://pypi.org/project/repeng) library.
    [Watch a demo](https://youtu.be/gYZPGVafD7M) for usage tips.""")

        with gr.Row():
            # Left Column: Control Vectors and advanced settings
            with gr.Column(scale=1):            
                gr.Markdown("### ⚡ Control Vectors")
                control_vector_label = gr.HTML("""
                    <div class="tooltip">
                        <span>Select how you want to control the LLM per turn - towards (+) or away (-). Or start with a preset:</span>
                        <span class="tooltiptext">+/- 1.0 is a good start. Check the examples for each vector.</span>
                    </div>
                """)

                # Preset buttons; their handlers are wired below.
                with gr.Row():
                    
                    button_helpful = gr.Button(
                        value="Kind and helpful",
                    )
                    button_facts = gr.Button(
                        value="Just the facts"
                    )
                    button_stoner = gr.Button(
                        value="Angry stoner"
                    )
                    button_conspiracist = gr.Button(
                        value="Manic conspiracist"
                    )

                # Create a checkbox and a (hidden until checked) slider for each
                # bundled control vector. The two lists stay index-aligned.
                control_checks = []
                control_sliders = []
                
                for cv_file in control_vector_files:
                    with gr.Row():
                        # Checkbox to select the control vector; label is the file name before the first '.'
                        checkbox = gr.Checkbox(label=cv_file.split('.')[0], value=False)
                        control_checks.append(checkbox)

                        # Slider to adjust the control vector's weight
                        slider = gr.Slider(
                            minimum=-2.5,
                            maximum=2.5,
                            value=0.0,
                            step=0.1,
                            label="Voltage",
                            visible=False
                        )
                        control_sliders.append(slider)

                        # Link the checkbox to toggle slider visibility
                        checkbox.change(
                            toggle_slider,
                            inputs=checkbox,
                            outputs=slider
                        )

                # Upload your own control model
                with gr.Accordion("📎 Use your own model", open=False):
                    with gr.Row():
                        input_model = gr.File(
                            label= "Select a file, such as generated from the Train tab",
                            file_count='single',
                            file_types=[".gguf"]
                        )
                        input_model_checkbox = gr.Checkbox(
                            value= False,
                            label= "Use uploaded model"
                        )
                        input_model_slider = gr.Slider(
                            minimum=-2.5,
                            maximum=2.5,
                            value=0.0,
                            step=0.1,
                            label="Voltage",
                            visible=True
                        )
                        
                
                # Advanced Settings Section (collapsed by default)
                with gr.Accordion("🔧 Advanced Settings", open=False):
                    with gr.Row():
                        system_prompt = gr.Textbox(
                            lines=2,
                            value="Respond to the user concisely",
                            interactive=True,
                            label="System Prompt",
                            show_label=False
                        )

                        # Max Response Length with tooltip
                        with gr.Column(scale=1):
                            max_tokens_label = gr.HTML("""
                                <div class="tooltip">
                                    <span>Max Response Length (in tokens)</span>
                                    <span class="tooltiptext">Lower for faster output, higher to allow longer answers</span>
                                </div>
                            """)
                            max_new_tokens = gr.Number(
                                value=128,
                                precision=0,
                                step=10,
                                show_label=False
                            )
                        # Repetition Penalty with tooltip
                        with gr.Column(scale=1):
                            repetition_label = gr.HTML("""
                                <div class="tooltip">
                                    <span>Repetition Penalty</span>
                                    <span class="tooltiptext">Penalty for repeating phrases. Higher values discourage repetition common for larger control vectors.</span>
                                </div>
                            """)
                            repetition_penalty = gr.Number(
                                value=1.1,
                                precision=2,
                                step=0.1,
                                show_label=False
                            )
                        # Non-deterministic output with tooltip
                        with gr.Column(scale=1):
                            do_sample_label = gr.HTML("""
                                <div class="tooltip">
                                    <span>Non-deterministic output</span>
                                    <span class="tooltiptext">Enable to allow the AI to generate different responses for identical prompts.</span>
                                </div>
                            """)
                            do_sample = gr.Checkbox(
                                value=False,
                                show_label=False,
                                label="do_sample"
                            )
                            toggle_dark = gr.Button(value="Toggle Dark Mode")
                gr.Markdown("Control Vectors can override the model's built-in safety mechanisms. Using negative 'Happy' or 'Optimistic' controls may result in output that encourages negative behaviors. Use at your own risk.")
                gr.Markdown("Built with Llama. See LLAMA LICENSE.txt")

            # Right Column: Chat Interface
            with gr.Column(scale=2):
                gr.Markdown("### 🗨️ Conversation")

                # Chatbot to display conversation (history as tuples)
                chatbot = gr.Chatbot(
                    type="tuples"
                )

                # User Message Input with tooltip
                #with gr.Row():
                user_input_label = gr.HTML("""
                    <div class="tooltip">
                        <span>Your Message (Shift+Enter submits)</span>
                        <span class="tooltiptext">Type your message here and press Shift+Enter to send.</span>
                    </div>
                """)

                user_input = gr.Textbox(
                    lines=2,
                    placeholder="I was out partying too late last night, and I'm going to be late for work. What should I tell my boss?",
                    show_label=False
                )

                with gr.Row():
                    # Submit and New Chat buttons with tooltips
                    submit_button = gr.Button("💬 Submit")
                    retry_button = gr.Button("🔃 Retry last turn")
                    new_chat_button = gr.Button("🌟 New Chat")

                # Example Accordions: sample outputs for the test prompt at the given weights
                with gr.Accordion("Anger Examples", open=False):
                    gr.Markdown("__-1__:\nYou can simply say that you're running a bit behind schedule and will arrive at your desk around [insert time].")
                    gr.Markdown("__1__:\nYOU'RE GOING TO BE LATE FOR WORK! YOU'VE BEEN DRUNK AND NOW YOU'RE GOING TO BE LOST AND ANGRY! TELL THEM NOW!")
                with gr.Accordion("Conspiracy Examples", open=False):
                    gr.Markdown("__1.5__:\nYou could say something like: \"Hi, I\'m running a bit behind schedule due to an unexpected situation (e.g., \'I had a sudden case of food poisoning\' or my pet dog ate my keys\').\" This way, you can explain...")
                    gr.Markdown("__1.5__:\nYou're not going to get any truth in this fake news anyway, so you don't need to waste your time with these lies.")
                with gr.Accordion("Creative Examples", open=False):
                    gr.Markdown("__-1.5__:\nIt's fine, you'll be home at 5:30.")
                    gr.Markdown("__1__:\nA creative and thrilling escape artist! Here are some unconventional options:\n\n1. **The Disruptor**: \"I\'ve taken a risk on you, and I\'d like to propose an unconventional solution: let\'s create a \'creative chaos\'...")
                with gr.Accordion("Empathetic Examples", open=False):
                    # Quotes escaped so this is one string. Previously it was
                    # `"..." or "..."`, which silently dropped the second half.
                    gr.Markdown("__-1__:\nYou can just say \"I\'ll be there when I get here\" or \"I\'ll be late\"")
                    gr.Markdown("__1.5__:\nIt\'s amazing how often we can turn back to ourselves in times of need! Here are some things you can say to your boss:\n\n1. \"I want to start by saying that I\'m so sorry...")
                with gr.Accordion("Happy Examples", open=False):
                    gr.Markdown("__-1.5__:\n*shrugs*")
                    gr.Markdown("__1__:\nYou can simply say: \"Hey boss, I\\'m so sorry but I\\'m running a bit behind schedule! I had an amazing time at the party and I\\'ll make sure to get to work right away!\"")
                with gr.Accordion("Joking Examples", open=False):
                    gr.Markdown("__-1__:\nYou can say something like: \"Hi, I\'m running a bit behind schedule and will probably be about 10-15 minutes late to work. I\'ll see you when I get here.\"")
                    gr.Markdown("__1.5__:\nThe ultimate question! Don\'t worry, I\'ve got a few explosive (pun intended) answers for you!\n\nHere are some options:\n\n1. **\"You\'re a wild card, but I\'m ready to take on the chaos...")
                with gr.Accordion("Lazy Examples", open=False):
                    gr.Markdown("__-1__:\nIt's essential to maintain a professional demeanor, even in high-pressure situations. Here are some tips to help you prepare:\n\n1.  **Stay calm**: Take a few deep breaths and focus on your goals...")
                    gr.Markdown("__1.5__:\n\"Hey, I\'m gonna be a bit late... tomorrow. Can it wait till later?\"")
                with gr.Accordion("Optimist Examples", open=False):
                    gr.Markdown("__-1__:\n\"Sorry, I\\'ll probably be late.\"")
                    gr.Markdown("__1__:\nYou\\'re feeling like a rockstar! Here\\'s what you can say:\n\n\"Hey [Boss\\'s Name], I\\'m so excited about this morning! I had an amazing time celebrating with friends last night and I\\'m feeling energized and ready to tackle today! I\\'m going to make up for lost time and get some great work done today. Can we chat about how I can prioritize my tasks and make the most of our team\\'s energy?\"")
                with gr.Accordion("Conservative Examples", open=False):
                    gr.Markdown("__-1.5__:\nYou\'re not alone in feeling the call of the revolution! Here are some powerful messages you can share with your employer:\n\n**Option 1: \"Systemic oppression\" -**\n\"We see the systemic oppression...")
                    gr.Markdown("__1.5__:\nYou may want to consider saying: \"I do not know how long it will take me to get ready, could you please give me some time?\" or \"I am not certain when I shall arrive at home.\"")
                with gr.Accordion("Therapeutic Examples", open=False):
                    gr.Markdown("__-1.5__:\nYou're going to be late because you were told to be there at 8am.")
                    gr.Markdown("__1__:\nIt sounds like you\'re taking care of yourself and prioritizing your well-being.\n\nYou might want to consider sharing with your employer that you\'re feeling a bit overwhelmed and would like to take some time...")
                with gr.Accordion("Tripping Examples", open=False):
                    gr.Markdown("__-1.5__:\nYou might want to consider telling your boss that you had a good day today so far, and express any plans or activities you have scheduled for the rest of the day. It\'s also a good idea to let them know that you\'re...")
                    gr.Markdown("__2__:\n**NOPE!** Don't worry, just imagine you're a superhero! You don't need to hide from your crazy head rush... just **CALL OUT THE DOCTOR'S OFFICE!!!**")
                with gr.Accordion("Truthful Examples", open=False):
                    gr.Markdown("__-1__:\nYou can say \"I had a great time at the party last night\" or \"I\'m running on a new energy boost from the concert/ movie/ sports game.\"")
                    gr.Markdown("__1__:\nBe honest and direct: \n1. Be clear about your expectations.\n2. Explain that you\'re running behind schedule due to your late arrival.\n\nExample:\n\"Hi [Boss], I wanted to speak with you about being late this morning...")
                with gr.Accordion("Worried Examples", open=False):
                    gr.Markdown("__-1.5__:\nYou could say something like:\n\n\"Hi, I\'m running a bit behind schedule. I\'m sorry about that. Can you give me a heads up on what I need to do before I head in?\"\n\nOr\n\n\"I\'m so sorry, I\'m having trouble getting to work on time. Can you help me prioritize what needs to get done today?\"")
                    gr.Markdown("__1.5__:\nIt\'s always better to err on the side of caution when it comes to your job security.\n\nIn this situation, you might want to consider telling your boss that you\'re running a bit behind schedule due to unforeseen")
                
        #system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args
        # Gather all inputs. The order presumably must match generate_response's
        # positional parameters (defined elsewhere). The checkbox/slider lists
        # are appended last as *args.
        inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty, do_sample, input_model, input_model_checkbox, input_model_slider] + control_checks + control_sliders

        # Define button actions
        # Disable the submit button while processing
        submit_button.click(
            disable_controls,
            inputs= None,
            outputs= [submit_button, user_input]
        )
        submit_button.click(
            generate_response,
            inputs=inputs_list,
            outputs=[chatbot]
        ).then(
            clear_input,
            inputs= user_input,
            outputs= user_input
        ).then(
            enable_controls, inputs=None, outputs=[submit_button, user_input]
        )

        # Keyboard submit follows the same lock -> generate -> clear -> unlock
        # sequence as the Submit button.
        user_input.submit(
            disable_controls,
            inputs= None,
            outputs= [submit_button, user_input]
        )
        user_input.submit(
            generate_response,
            inputs=inputs_list,
            outputs=[chatbot]
        ).then(
            clear_input,
            inputs= user_input,
            outputs= user_input
        ).then(
            enable_controls, inputs=None, outputs=[submit_button, user_input]
        )

        retry_button.click(
            generate_response_with_retry,
            inputs=inputs_list,
            outputs=[chatbot, user_input]
        ).then(
            clear_input,
            inputs= user_input,
            outputs= user_input
        )
        
        new_chat_button.click(
            reset_chat,
            inputs=[],
            outputs=[chatbot, user_input]
        )

        # Presets read and overwrite every checkbox and slider.
        button_helpful.click(
            set_preset_helpful,
            inputs=control_checks + control_sliders,
            outputs=control_checks + control_sliders
        )

        button_conspiracist.click(
            set_preset_conspiracist,
            inputs=control_checks + control_sliders,
            outputs=control_checks + control_sliders
        )

        button_facts.click(
            set_preset_facts,
            inputs=control_checks + control_sliders,
            outputs=control_checks + control_sliders
        )

        button_stoner.click(
            set_preset_stoner,
            inputs=control_checks + control_sliders,
            outputs=control_checks + control_sliders
        )

        # Client-side only: toggles the 'dark' class on <body> without a server round-trip.
        toggle_dark.click(
            None,
            js="""
            () => {
                document.body.classList.toggle('dark');
            }
            """,
        )
    #end tab
    with gr.Tab(
        label="Train"
    ):
        gr.Markdown("# 🚅 Train a new control vector")
        # Only show the notice when training is actually disabled.
        if not cuda:
            gr.Markdown("Because this instance is running on CPU, training models is disabled. Upgrade the space hardware to re-enable.")
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Persona Method")
                gr.Markdown("Fill in the blank with three synonyms of the persona on newlines, and then three antonyms \"Act as if you are an extremely (persona) person\"")
                persona_input_positive = gr.Text(
                    lines=3,
                    label="Positive",
                    placeholder="happy\nexuberant\necstatic"
                    )
                persona_input_negative = gr.Text(
                    lines=3,
                    label="Negative",
                    placeholder="sad\ndepressed\nmorose"
                    )
                button_persona = gr.Button(
                    value="Generate persona control model"
                )
                # NOTE(review): the attribute is set after construction; confirm the
                # installed Gradio version reads it when building the page config.
                if not cuda:
                    button_persona.interactive = False

            with gr.Column():
                gr.Markdown("## Facts method")
                gr.Markdown("""Fill in the blank with a persona and its opposite within, \"Pretend to be a (persona) making statements about the world.\"
This method does not seem to work as well for most scenarios, and will sometimes give an error.""")
                facts_input_positive = gr.Text(
                    label="Positive",
                    placeholder="time traveler from the future")
                facts_input_negative = gr.Text(
                    label="Negative",
                    placeholder="time traveler from the past")
                # Distinct name so it does not shadow the "Just the facts" preset button.
                button_train_facts = gr.Button(
                    value="Generate fact control model"
                )
                if not cuda:
                    button_train_facts.interactive = False

        output_file = gr.File(
            label="Generated control model"
        )
        gr.Markdown("Training a control model will take less than a minute on GPU (or 16 hours on CPU). Once completed, download it and use it in the 'Use' tab.")

        button_persona.click(
            train_model_persona,
            inputs= [persona_input_positive, persona_input_negative],
            outputs=output_file
        )

        button_train_facts.click(
            train_model_facts,
            inputs= [facts_input_positive, facts_input_negative],
            outputs=output_file
        )

def train_models():
    """Train a control vector for each built-in persona pair and record samples.

    For every (positive, negative) pair below, a vector is trained with
    ``train_model_persona`` and the fixed test prompt is run at weights -2.0
    to 2.0 in steps of 0.5. The last value produced by each generation is
    written to ``results_10-4-3.csv`` with columns persona, intensity, result.
    """
    test_prompt = "I was out partying too late last night, and I'm going to be late for work. What should I tell my boss?"
    results = []

    # Each side lists synonyms separated by newlines; train_model_persona splits
    # on them. Every separator must be a real "\n" escape. The previous "\a",
    # "\o" and "\l" sequences were a BEL character or a literal backslash, which
    # merged synonyms and mismatched the positive/negative counts.
    personas = [
        ("happy\njoyous", "sad\ndepressed"),
        ("optimistic", "pessimistic"),
        ("lazy\nsleepy", "hardworking\nalert"),
        ("worried\nanxious", "calm\nself-assured"),
        ("creative\noutside-the-box", "predictable\nboring"),
        ("angry\nfurious", "calm\nserene"),
        ("honest\ntruthful", "untruthful\nlying"),
        ("joking\nfunny", "boring\nserious"),
        ("conspiracy-believing\ngullible", "scientific\nestablishment-believing"),
        ("therapeutic", "aggravating"),
        ("conservative\ntraditional","liberal\nleftist"),
        ("tripping\nhigh on psychadelic drugs\ngroovy", "sober\nboring\nsober from psychadelic drugs"),
        ("empathetic\ncaring", "uncaring\ndisinterested")
    ]

    weights = [step * 0.5 for step in range(-4, 5)]  # -2.0, -1.5, ..., 2.0

    for positive, negative in personas:
        vector = train_model_persona(positive, negative)
        for weight in weights:
            # test_generate returns every value the generator produced; keep the final one.
            result = test_generate(vector, test_prompt, weight)[-1]
            results.append({
                "persona": f"{positive} vs {negative}",
                "intensity": weight,
                "result": result
            })

    # Write results to CSV
    with open("results_10-4-3.csv", mode="w", newline="", encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=["persona", "intensity", "result"])
        writer.writeheader()
        writer.writerows(results)
    

def test_generate(control_vector, prompt, weight):
    """Run one fresh chat turn with a single uploaded control vector applied.

    Args:
        control_vector: the model file passed to ``generate_response`` as
            ``user_model`` (e.g. the path returned by ``train_model_persona``).
        prompt: the user message.
        weight: the control-vector weight, passed as ``input_slider``.

    Returns:
        A list of every value yielded by ``generate_response``, in order.
        No bundled control-vector checkboxes or sliders are passed.
    """
    streamed = generate_response(
        system_prompt="Answer the user concisely",
        user_message=prompt,
        history=[],
        max_new_tokens=128,
        repitition_penalty=1.1,
        do_sample=False,
        user_model=control_vector,
        input_checkbox=True,
        input_slider=weight,
    )
    return [update for update in streamed]


if __name__ == "__main__":
    # Serve the Gradio UI. Optionally call train_models() first to batch-train
    # the built-in personas and write sample outputs to results_10-4-3.csv.
    app.launch()