piyushgrover commited on
Commit
7d65571
·
1 Parent(s): e63234b

added new files

Browse files
Files changed (3) hide show
  1. app.py +104 -0
  2. requirements.txt +12 -0
  3. s27erav1.ipynb +1314 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ model_name = "microsoft/phi-2"
6
+
7
+ # Reload model in FP16 and merge it with LoRA weights
8
+ base_model = AutoModelForCausalLM.from_pretrained(
9
+ model_name,
10
+ low_cpu_mem_usage=True,
11
+ return_dict=True,
12
+ torch_dtype=torch.float16,
13
+ trust_remote_code=True
14
+ # device_map=device_map,
15
+ )
16
+
17
+ from peft import PeftModel
18
+ new_model = "piyushgrover/phi-2-qlora-adapter-custom"
19
+ model = PeftModel.from_pretrained(base_model, new_model)
20
+ model = model.merge_and_unload()
21
+
22
+ # Reload tokenizer to save it
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+ tokenizer.padding_side = "right"
26
+
27
+ from transformers import pipeline
28
+ gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_length=300)
29
+
30
+
31
+
32
+ def fn_query_on_load():
33
+ prompt = "Explain nuclear physics to a five year old kid."
34
+ return prompt
35
+
36
+
37
+ def generate_response(input):
38
+
39
+ prompt = f"### Human: {input}\n\n### Assistant: "
40
+ result = gen(prompt)
41
+
42
+ resp = result[0]['generated_text'].replace(prompt, '')
43
+ resp_arr = resp.split('###')
44
+
45
+ final_resp = resp_arr[0]
46
+
47
+ '''
48
+
49
+ start_ids = encode(start)
50
+ x = (torch.tensor(start_ids, dtype=torch.long, device=device_type)[None, ...])
51
+
52
+ out_text = ''
53
+ with torch.no_grad():
54
+ with ctx:
55
+ for k in range(num_samples):
56
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
57
+ out_text += decode(y[0].tolist())
58
+ out_text += '\n-o-o-o-o-o-o-o-\n\n'
59
+
60
+ '''
61
+ return {
62
+ output: final_resp
63
+ }
64
+
65
+
66
+ with gr.Blocks() as app:
67
+ with gr.Row():
68
+ gr.Markdown(
69
+ """
70
+ # PhiGPT - Ask Me Anything (AI Assistant)
71
+ ### Phi2 Model (2Billion parameters) Fine-tuned on OpenAssistant/oasst1 Dataset :)
72
+ #### [Please be patient as it's running on CPU & not GPU]
73
+ """)
74
+
75
+ with gr.Row(visible=True):
76
+ search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')
77
+
78
+ with gr.Row():
79
+ submit_btn = gr.Button("Submit", variant='primary')
80
+ clear_btn = gr.ClearButton()
81
+ with gr.Row():
82
+ with gr.Row():
83
+ output = gr.Textbox(lines=15, interactive=False, label='Response ')
84
+
85
+ def clear_data():
86
+ return {
87
+ output: None,
88
+ search_text: None
89
+ }
90
+
91
+ clear_btn.click(clear_data, None, [output, search_text])
92
+
93
+
94
+ submit_btn.click(
95
+ generate_response,
96
+ search_text,
97
+ output
98
+ )
99
+
100
+
101
+ '''
102
+ Launch the app
103
+ '''
104
+ app.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ transformers
4
+ datasets
5
+ tiktoken
6
+ wandb
7
+ tqdm
8
+ trl
9
+ accelerate
10
+ git+https://github.com/huggingface/peft.git
11
+ bitsandbytes
12
+ einops
s27erav1.ipynb ADDED
@@ -0,0 +1,1314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
8
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
9
+ "execution": {
10
+ "iopub.execute_input": "2023-12-21T18:16:54.952190Z",
11
+ "iopub.status.busy": "2023-12-21T18:16:54.951899Z",
12
+ "iopub.status.idle": "2023-12-21T18:17:53.065431Z",
13
+ "shell.execute_reply": "2023-12-21T18:17:53.064016Z",
14
+ "shell.execute_reply.started": "2023-12-21T18:16:54.952164Z"
15
+ }
16
+ },
17
+ "outputs": [
18
+ {
19
+ "name": "stdout",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "Collecting git+https://github.com/huggingface/peft.git\n",
23
+ " Cloning https://github.com/huggingface/peft.git to /tmp/pip-req-build-dpzetz7o\n",
24
+ " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/peft.git /tmp/pip-req-build-dpzetz7o\n",
25
+ " Resolved https://github.com/huggingface/peft.git to commit 993836ff90791289b94d27caa46385eec958e147\n",
26
+ " Installing build dependencies ... \u001b[?25ldone\n",
27
+ "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
28
+ "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
29
+ "\u001b[?25hCollecting trl\n",
30
+ " Obtaining dependency information for trl from https://files.pythonhosted.org/packages/0d/44/c406c3cf5981bddb16ff72acb5ca235888db4073d868cf51bd143bef3aad/trl-0.7.4-py3-none-any.whl.metadata\n",
31
+ " Downloading trl-0.7.4-py3-none-any.whl.metadata (10 kB)\n",
32
+ "Requirement already satisfied: transformers in /opt/conda/lib/python3.10/site-packages (4.36.0)\n",
33
+ "Collecting transformers\n",
34
+ " Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/20/0a/739426a81f7635b422fbe6cb8d1d99d1235579a6ac8024c13d743efa6847/transformers-4.36.2-py3-none-any.whl.metadata\n",
35
+ " Downloading transformers-4.36.2-py3-none-any.whl.metadata (126 kB)\n",
36
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m126.8/126.8 kB\u001b[0m \u001b[31m670.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
37
+ "\u001b[?25hRequirement already satisfied: accelerate in /opt/conda/lib/python3.10/site-packages (0.25.0)\n",
38
+ "Requirement already satisfied: torch>=1.4.0 in /opt/conda/lib/python3.10/site-packages (from trl) (2.0.0)\n",
39
+ "Requirement already satisfied: numpy>=1.18.2 in /opt/conda/lib/python3.10/site-packages (from trl) (1.24.3)\n",
40
+ "Requirement already satisfied: datasets in /opt/conda/lib/python3.10/site-packages (from trl) (2.1.0)\n",
41
+ "Collecting tyro>=0.5.11 (from trl)\n",
42
+ " Obtaining dependency information for tyro>=0.5.11 from https://files.pythonhosted.org/packages/c5/11/abdf67467d06713b431618732a43f82d1b1f02120107b05a789afbcdf54d/tyro-0.6.0-py3-none-any.whl.metadata\n",
43
+ " Downloading tyro-0.6.0-py3-none-any.whl.metadata (7.5 kB)\n",
44
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers) (3.12.2)\n",
45
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /opt/conda/lib/python3.10/site-packages (from transformers) (0.19.4)\n",
46
+ "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers) (21.3)\n",
47
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers) (6.0.1)\n",
48
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers) (2023.8.8)\n",
49
+ "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers) (2.31.0)\n",
50
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/conda/lib/python3.10/site-packages (from transformers) (0.15.0)\n",
51
+ "Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from transformers) (0.4.1)\n",
52
+ "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers) (4.66.1)\n",
53
+ "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate) (5.9.3)\n",
54
+ "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (2023.12.2)\n",
55
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.5.0)\n",
56
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->transformers) (3.0.9)\n",
57
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.4.0->trl) (1.12)\n",
58
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.4.0->trl) (3.1)\n",
59
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.4.0->trl) (3.1.2)\n",
60
+ "Requirement already satisfied: docstring-parser>=0.14.1 in /opt/conda/lib/python3.10/site-packages (from tyro>=0.5.11->trl) (0.15)\n",
61
+ "Requirement already satisfied: rich>=11.1.0 in /opt/conda/lib/python3.10/site-packages (from tyro>=0.5.11->trl) (13.5.2)\n",
62
+ "Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n",
63
+ " Obtaining dependency information for shtab>=1.5.6 from https://files.pythonhosted.org/packages/40/ad/7227da64498eaa7abecee4311008f70869e156014b3270cec36e2e70cd31/shtab-1.6.5-py3-none-any.whl.metadata\n",
64
+ " Downloading shtab-1.6.5-py3-none-any.whl.metadata (7.3 kB)\n",
65
+ "Requirement already satisfied: pyarrow>=5.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (11.0.0)\n",
66
+ "Requirement already satisfied: dill in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (0.3.7)\n",
67
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (2.0.3)\n",
68
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (3.4.1)\n",
69
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (0.70.15)\n",
70
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (3.8.5)\n",
71
+ "Requirement already satisfied: responses<0.19 in /opt/conda/lib/python3.10/site-packages (from datasets->trl) (0.18.0)\n",
72
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (3.2.0)\n",
73
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (3.4)\n",
74
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (1.26.15)\n",
75
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->transformers) (2023.11.17)\n",
76
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->trl) (23.1.0)\n",
77
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->trl) (6.0.4)\n",
78
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->trl) (4.0.3)\n",
79
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->trl) (1.9.2)\n",
80
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->trl) (1.4.0)\n",
81
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->trl) (1.3.1)\n",
82
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/lib/python3.10/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n",
83
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/lib/python3.10/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n",
84
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.4.0->trl) (2.1.3)\n",
85
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets->trl) (2.8.2)\n",
86
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets->trl) (2023.3)\n",
87
+ "Requirement already satisfied: tzdata>=2022.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets->trl) (2023.3)\n",
88
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.4.0->trl) (1.3.0)\n",
89
+ "Requirement already satisfied: mdurl~=0.1 in /opt/conda/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.0)\n",
90
+ "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets->trl) (1.16.0)\n",
91
+ "Downloading trl-0.7.4-py3-none-any.whl (133 kB)\n",
92
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.9/133.9 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
93
+ "\u001b[?25hDownloading transformers-4.36.2-py3-none-any.whl (8.2 MB)\n",
94
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
95
+ "\u001b[?25hDownloading tyro-0.6.0-py3-none-any.whl (100 kB)\n",
96
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.9/100.9 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
97
+ "\u001b[?25hDownloading shtab-1.6.5-py3-none-any.whl (13 kB)\n",
98
+ "Building wheels for collected packages: peft\n",
99
+ " Building wheel for peft (pyproject.toml) ... \u001b[?25ldone\n",
100
+ "\u001b[?25h Created wheel for peft: filename=peft-0.7.2.dev0-py3-none-any.whl size=169329 sha256=65c9f890817815f066ee515202e5f5044739b6cb22fcf4ef4280bc3ee8339237\n",
101
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-fsxcdvsj/wheels/d7/c7/de/1368fac8590e1b103ddc2ec2a28ad51d83aded1a3830e8a087\n",
102
+ "Successfully built peft\n",
103
+ "Installing collected packages: shtab, tyro, transformers, trl, peft\n",
104
+ " Attempting uninstall: transformers\n",
105
+ " Found existing installation: transformers 4.36.0\n",
106
+ " Uninstalling transformers-4.36.0:\n",
107
+ " Successfully uninstalled transformers-4.36.0\n",
108
+ "Successfully installed peft-0.7.2.dev0 shtab-1.6.5 transformers-4.36.2 trl-0.7.4 tyro-0.6.0\n",
109
+ "Requirement already satisfied: datasets in /opt/conda/lib/python3.10/site-packages (2.1.0)\n",
110
+ "Collecting datasets\n",
111
+ " Obtaining dependency information for datasets from https://files.pythonhosted.org/packages/e2/cf/db41e572d7ed958e8679018f8190438ef700aeb501b62da9e1eed9e4d69a/datasets-2.15.0-py3-none-any.whl.metadata\n",
112
+ " Downloading datasets-2.15.0-py3-none-any.whl.metadata (20 kB)\n",
113
+ "Collecting bitsandbytes\n",
114
+ " Obtaining dependency information for bitsandbytes from https://files.pythonhosted.org/packages/d9/8d/b62d4fb02587e293e5b91b68bbcaa2d88c6a0360b622e9521d4bd07a20cd/bitsandbytes-0.41.3.post2-py3-none-any.whl.metadata\n",
115
+ " Downloading bitsandbytes-0.41.3.post2-py3-none-any.whl.metadata (9.8 kB)\n",
116
+ "Collecting einops\n",
117
+ " Obtaining dependency information for einops from https://files.pythonhosted.org/packages/29/0b/2d1c0ebfd092e25935b86509a9a817159212d82aa43d7fb07eca4eeff2c2/einops-0.7.0-py3-none-any.whl.metadata\n",
118
+ " Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)\n",
119
+ "Requirement already satisfied: wandb in /opt/conda/lib/python3.10/site-packages (0.16.1)\n",
120
+ "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets) (1.24.3)\n",
121
+ "Requirement already satisfied: pyarrow>=8.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (11.0.0)\n",
122
+ "Collecting pyarrow-hotfix (from datasets)\n",
123
+ " Obtaining dependency information for pyarrow-hotfix from https://files.pythonhosted.org/packages/e4/f4/9ec2222f5f5f8ea04f66f184caafd991a39c8782e31f5b0266f101cb68ca/pyarrow_hotfix-0.6-py3-none-any.whl.metadata\n",
124
+ " Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)\n",
125
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.3.7)\n",
126
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets) (2.0.3)\n",
127
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (2.31.0)\n",
128
+ "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (4.66.1)\n",
129
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets) (3.4.1)\n",
130
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets) (0.70.15)\n",
131
+ "Collecting fsspec[http]<=2023.10.0,>=2023.1.0 (from datasets)\n",
132
+ " Obtaining dependency information for fsspec[http]<=2023.10.0,>=2023.1.0 from https://files.pythonhosted.org/packages/e8/f6/3eccfb530aac90ad1301c582da228e4763f19e719ac8200752a4841b0b2d/fsspec-2023.10.0-py3-none-any.whl.metadata\n",
133
+ " Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)\n",
134
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets) (3.8.5)\n",
135
+ "Requirement already satisfied: huggingface-hub>=0.18.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.19.4)\n",
136
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets) (21.3)\n",
137
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (6.0.1)\n",
138
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /opt/conda/lib/python3.10/site-packages (from wandb) (8.1.7)\n",
139
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (3.1.32)\n",
140
+ "Requirement already satisfied: psutil>=5.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (5.9.3)\n",
141
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (1.39.0)\n",
142
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (0.4.0)\n",
143
+ "Requirement already satisfied: setproctitle in /opt/conda/lib/python3.10/site-packages (from wandb) (1.3.3)\n",
144
+ "Requirement already satisfied: setuptools in /opt/conda/lib/python3.10/site-packages (from wandb) (68.1.2)\n",
145
+ "Requirement already satisfied: appdirs>=1.4.3 in /opt/conda/lib/python3.10/site-packages (from wandb) (1.4.4)\n",
146
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (3.20.3)\n",
147
+ "Requirement already satisfied: six>=1.4.0 in /opt/conda/lib/python3.10/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
148
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (23.1.0)\n",
149
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (3.2.0)\n",
150
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.4)\n",
151
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n",
152
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.2)\n",
153
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.0)\n",
154
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
155
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /opt/conda/lib/python3.10/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
156
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.18.0->datasets) (3.12.2)\n",
157
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.18.0->datasets) (4.5.0)\n",
158
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->datasets) (3.0.9)\n",
159
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.4)\n",
160
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (1.26.15)\n",
161
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2023.11.17)\n",
162
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
163
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2023.3)\n",
164
+ "Requirement already satisfied: tzdata>=2022.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2023.3)\n",
165
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /opt/conda/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
166
+ "Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n",
167
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m22.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
168
+ "\u001b[?25hDownloading bitsandbytes-0.41.3.post2-py3-none-any.whl (92.6 MB)\n",
169
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
170
+ "\u001b[?25hDownloading einops-0.7.0-py3-none-any.whl (44 kB)\n",
171
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
172
+ "\u001b[?25hDownloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n",
173
+ "Downloading fsspec-2023.10.0-py3-none-any.whl (166 kB)\n",
174
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m166.4/166.4 kB\u001b[0m \u001b[31m18.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
175
+ "\u001b[?25hInstalling collected packages: bitsandbytes, pyarrow-hotfix, fsspec, einops, datasets\n",
176
+ " Attempting uninstall: fsspec\n",
177
+ " Found existing installation: fsspec 2023.12.2\n",
178
+ " Uninstalling fsspec-2023.12.2:\n",
179
+ " Successfully uninstalled fsspec-2023.12.2\n",
180
+ " Attempting uninstall: datasets\n",
181
+ " Found existing installation: datasets 2.1.0\n",
182
+ " Uninstalling datasets-2.1.0:\n",
183
+ " Successfully uninstalled datasets-2.1.0\n",
184
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
185
+ "cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.\n",
186
+ "cuml 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.\n",
187
+ "dask-cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.\n",
188
+ "cudf 23.8.0 requires pandas<1.6.0dev0,>=1.3, but you have pandas 2.0.3 which is incompatible.\n",
189
+ "cudf 23.8.0 requires protobuf<5,>=4.21, but you have protobuf 3.20.3 which is incompatible.\n",
190
+ "cuml 23.8.0 requires dask==2023.7.1, but you have dask 2023.12.0 which is incompatible.\n",
191
+ "cuml 23.8.0 requires distributed==2023.7.1, but you have distributed 2023.12.0 which is incompatible.\n",
192
+ "dask-cuda 23.8.0 requires dask==2023.7.1, but you have dask 2023.12.0 which is incompatible.\n",
193
+ "dask-cuda 23.8.0 requires distributed==2023.7.1, but you have distributed 2023.12.0 which is incompatible.\n",
194
+ "dask-cuda 23.8.0 requires pandas<1.6.0dev0,>=1.3, but you have pandas 2.0.3 which is incompatible.\n",
195
+ "dask-cudf 23.8.0 requires dask==2023.7.1, but you have dask 2023.12.0 which is incompatible.\n",
196
+ "dask-cudf 23.8.0 requires distributed==2023.7.1, but you have distributed 2023.12.0 which is incompatible.\n",
197
+ "dask-cudf 23.8.0 requires pandas<1.6.0dev0,>=1.3, but you have pandas 2.0.3 which is incompatible.\n",
198
+ "gcsfs 2023.6.0 requires fsspec==2023.6.0, but you have fsspec 2023.10.0 which is incompatible.\n",
199
+ "raft-dask 23.8.0 requires dask==2023.7.1, but you have dask 2023.12.0 which is incompatible.\n",
200
+ "raft-dask 23.8.0 requires distributed==2023.7.1, but you have distributed 2023.12.0 which is incompatible.\n",
201
+ "s3fs 2023.12.2 requires fsspec==2023.12.2, but you have fsspec 2023.10.0 which is incompatible.\u001b[0m\u001b[31m\n",
202
+ "\u001b[0mSuccessfully installed bitsandbytes-0.41.3.post2 datasets-2.15.0 einops-0.7.0 fsspec-2023.10.0 pyarrow-hotfix-0.6\n"
203
+ ]
204
+ }
205
+ ],
206
+ "source": [
207
+ "!pip install -U trl transformers accelerate git+https://github.com/huggingface/peft.git\n",
208
+ "!pip install -U datasets bitsandbytes einops wandb\n"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 2,
214
+ "metadata": {
215
+ "execution": {
216
+ "iopub.execute_input": "2023-12-21T18:22:23.573426Z",
217
+ "iopub.status.busy": "2023-12-21T18:22:23.572634Z",
218
+ "iopub.status.idle": "2023-12-21T18:22:52.776357Z",
219
+ "shell.execute_reply": "2023-12-21T18:22:52.775443Z",
220
+ "shell.execute_reply.started": "2023-12-21T18:22:23.573393Z"
221
+ }
222
+ },
223
+ "outputs": [
224
+ {
225
+ "data": {
226
+ "text/plain": [
227
+ "config.json: 0%| | 0.00/755 [00:00<?, ?B/s]"
228
+ ]
229
+ },
230
+ "metadata": {},
231
+ "output_type": "display_data"
232
+ },
233
+ {
234
+ "data": {
235
+ "text/plain": [
236
+ "configuration_phi.py: 0%| | 0.00/2.03k [00:00<?, ?B/s]"
237
+ ]
238
+ },
239
+ "metadata": {},
240
+ "output_type": "display_data"
241
+ },
242
+ {
243
+ "name": "stderr",
244
+ "output_type": "stream",
245
+ "text": [
246
+ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
247
+ "- configuration_phi.py\n",
248
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
249
+ ]
250
+ },
251
+ {
252
+ "data": {
253
+ "text/plain": [
254
+ "modeling_phi.py: 0%| | 0.00/33.4k [00:00<?, ?B/s]"
255
+ ]
256
+ },
257
+ "metadata": {},
258
+ "output_type": "display_data"
259
+ },
260
+ {
261
+ "name": "stderr",
262
+ "output_type": "stream",
263
+ "text": [
264
+ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
265
+ "- modeling_phi.py\n",
266
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
267
+ ]
268
+ },
269
+ {
270
+ "data": {
271
+ "text/plain": [
272
+ "model.safetensors.index.json: 0%| | 0.00/24.3k [00:00<?, ?B/s]"
273
+ ]
274
+ },
275
+ "metadata": {},
276
+ "output_type": "display_data"
277
+ },
278
+ {
279
+ "data": {
280
+ "text/plain": [
281
+ "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
282
+ ]
283
+ },
284
+ "metadata": {},
285
+ "output_type": "display_data"
286
+ },
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "model-00001-of-00002.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
291
+ ]
292
+ },
293
+ "metadata": {},
294
+ "output_type": "display_data"
295
+ },
296
+ {
297
+ "data": {
298
+ "text/plain": [
299
+ "model-00002-of-00002.safetensors: 0%| | 0.00/577M [00:00<?, ?B/s]"
300
+ ]
301
+ },
302
+ "metadata": {},
303
+ "output_type": "display_data"
304
+ },
305
+ {
306
+ "name": "stderr",
307
+ "output_type": "stream",
308
+ "text": [
309
+ "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n",
310
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
311
+ ]
312
+ },
313
+ {
314
+ "data": {
315
+ "text/plain": [
316
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
317
+ ]
318
+ },
319
+ "metadata": {},
320
+ "output_type": "display_data"
321
+ },
322
+ {
323
+ "data": {
324
+ "text/plain": [
325
+ "generation_config.json: 0%| | 0.00/69.0 [00:00<?, ?B/s]"
326
+ ]
327
+ },
328
+ "metadata": {},
329
+ "output_type": "display_data"
330
+ },
331
+ {
332
+ "name": "stdout",
333
+ "output_type": "stream",
334
+ "text": [
335
+ "PhiForCausalLM(\n",
336
+ " (transformer): PhiModel(\n",
337
+ " (embd): Embedding(\n",
338
+ " (wte): Embedding(51200, 2560)\n",
339
+ " (drop): Dropout(p=0.0, inplace=False)\n",
340
+ " )\n",
341
+ " (h): ModuleList(\n",
342
+ " (0-31): 32 x ParallelBlock(\n",
343
+ " (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
344
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
345
+ " (mixer): MHA(\n",
346
+ " (rotary_emb): RotaryEmbedding()\n",
347
+ " (Wqkv): Linear4bit(in_features=2560, out_features=7680, bias=True)\n",
348
+ " (out_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)\n",
349
+ " (inner_attn): SelfAttention(\n",
350
+ " (drop): Dropout(p=0.0, inplace=False)\n",
351
+ " )\n",
352
+ " (inner_cross_attn): CrossAttention(\n",
353
+ " (drop): Dropout(p=0.0, inplace=False)\n",
354
+ " )\n",
355
+ " )\n",
356
+ " (mlp): MLP(\n",
357
+ " (fc1): Linear4bit(in_features=2560, out_features=10240, bias=True)\n",
358
+ " (fc2): Linear4bit(in_features=10240, out_features=2560, bias=True)\n",
359
+ " (act): NewGELUActivation()\n",
360
+ " )\n",
361
+ " )\n",
362
+ " )\n",
363
+ " )\n",
364
+ " (lm_head): CausalLMHead(\n",
365
+ " (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
366
+ " (linear): Linear(in_features=2560, out_features=51200, bias=True)\n",
367
+ " )\n",
368
+ " (loss): CausalLMLoss(\n",
369
+ " (loss_fct): CrossEntropyLoss()\n",
370
+ " )\n",
371
+ ")\n"
372
+ ]
373
+ }
374
+ ],
375
+ "source": [
376
+ "import torch\n",
377
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
378
+ "\n",
379
+ "#model_name = \"ybelkada/falcon-7b-sharded-bf16\"\n",
380
+ "model_name = \"microsoft/phi-2\"\n",
381
+ "bnb_config = BitsAndBytesConfig(\n",
382
+ " load_in_4bit=True,\n",
383
+ " bnb_4bit_quant_type=\"nf4\",\n",
384
+ " bnb_4bit_compute_dtype=torch.float16,\n",
385
+ ")\n",
386
+ "\n",
387
+ "model = AutoModelForCausalLM.from_pretrained(\n",
388
+ " model_name,\n",
389
+ " quantization_config=bnb_config,\n",
390
+ " trust_remote_code=True\n",
391
+ ")\n",
392
+ "model.config.use_cache = False\n",
393
+ "print(model)"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 3,
399
+ "metadata": {
400
+ "execution": {
401
+ "iopub.execute_input": "2023-12-21T18:22:59.972504Z",
402
+ "iopub.status.busy": "2023-12-21T18:22:59.971964Z",
403
+ "iopub.status.idle": "2023-12-21T18:23:02.356756Z",
404
+ "shell.execute_reply": "2023-12-21T18:23:02.355664Z",
405
+ "shell.execute_reply.started": "2023-12-21T18:22:59.972462Z"
406
+ }
407
+ },
408
+ "outputs": [
409
+ {
410
+ "data": {
411
+ "text/plain": [
412
+ "tokenizer_config.json: 0%| | 0.00/7.34k [00:00<?, ?B/s]"
413
+ ]
414
+ },
415
+ "metadata": {},
416
+ "output_type": "display_data"
417
+ },
418
+ {
419
+ "data": {
420
+ "text/plain": [
421
+ "vocab.json: 0%| | 0.00/798k [00:00<?, ?B/s]"
422
+ ]
423
+ },
424
+ "metadata": {},
425
+ "output_type": "display_data"
426
+ },
427
+ {
428
+ "data": {
429
+ "text/plain": [
430
+ "merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
431
+ ]
432
+ },
433
+ "metadata": {},
434
+ "output_type": "display_data"
435
+ },
436
+ {
437
+ "data": {
438
+ "text/plain": [
439
+ "tokenizer.json: 0%| | 0.00/2.11M [00:00<?, ?B/s]"
440
+ ]
441
+ },
442
+ "metadata": {},
443
+ "output_type": "display_data"
444
+ },
445
+ {
446
+ "data": {
447
+ "text/plain": [
448
+ "added_tokens.json: 0%| | 0.00/1.08k [00:00<?, ?B/s]"
449
+ ]
450
+ },
451
+ "metadata": {},
452
+ "output_type": "display_data"
453
+ },
454
+ {
455
+ "data": {
456
+ "text/plain": [
457
+ "special_tokens_map.json: 0%| | 0.00/99.0 [00:00<?, ?B/s]"
458
+ ]
459
+ },
460
+ "metadata": {},
461
+ "output_type": "display_data"
462
+ },
463
+ {
464
+ "name": "stderr",
465
+ "output_type": "stream",
466
+ "text": [
467
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
468
+ ]
469
+ }
470
+ ],
471
+ "source": [
472
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
473
+ "tokenizer.pad_token = tokenizer.eos_token"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 4,
479
+ "metadata": {
480
+ "execution": {
481
+ "iopub.execute_input": "2023-12-21T18:23:12.232728Z",
482
+ "iopub.status.busy": "2023-12-21T18:23:12.231764Z",
483
+ "iopub.status.idle": "2023-12-21T18:23:12.276671Z",
484
+ "shell.execute_reply": "2023-12-21T18:23:12.275949Z",
485
+ "shell.execute_reply.started": "2023-12-21T18:23:12.232691Z"
486
+ }
487
+ },
488
+ "outputs": [],
489
+ "source": [
490
+ "from peft import LoraConfig\n",
491
+ "\n",
492
+ "lora_alpha = 16\n",
493
+ "lora_dropout = 0.1\n",
494
+ "lora_r = 64\n",
495
+ "\n",
496
+ "'''target_modules = [\n",
497
+ " \"query_key_value\",#Wqkv\n",
498
+ " \"dense\",#out_proj\n",
499
+ " \"dense_h_to_4h\", #fc1\n",
500
+ " \"dense_4h_to_h\", #fc2\n",
501
+ "]'''\n",
502
+ "\n",
503
+ "target_modules = [\n",
504
+ " \"Wqkv\",\n",
505
+ " \"out_proj\",\n",
506
+ " \"fc1\",\n",
507
+ " \"fc2\"\n",
508
+ "]\n",
509
+ "\n",
510
+ "peft_config = LoraConfig(\n",
511
+ " lora_alpha=lora_alpha,\n",
512
+ " lora_dropout=lora_dropout,\n",
513
+ " r=lora_r,\n",
514
+ " bias=\"none\",\n",
515
+ " task_type=\"CAUSAL_LM\",\n",
516
+ " target_modules=target_modules\n",
517
+ ")"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "code",
522
+ "execution_count": 5,
523
+ "metadata": {
524
+ "execution": {
525
+ "iopub.execute_input": "2023-12-21T18:23:20.097392Z",
526
+ "iopub.status.busy": "2023-12-21T18:23:20.096964Z",
527
+ "iopub.status.idle": "2023-12-21T18:23:26.215631Z",
528
+ "shell.execute_reply": "2023-12-21T18:23:26.214868Z",
529
+ "shell.execute_reply.started": "2023-12-21T18:23:20.097362Z"
530
+ }
531
+ },
532
+ "outputs": [
533
+ {
534
+ "data": {
535
+ "text/plain": [
536
+ "Downloading readme: 0%| | 0.00/10.2k [00:00<?, ?B/s]"
537
+ ]
538
+ },
539
+ "metadata": {},
540
+ "output_type": "display_data"
541
+ },
542
+ {
543
+ "data": {
544
+ "text/plain": [
545
+ "Downloading data files: 0%| | 0/2 [00:00<?, ?it/s]"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "output_type": "display_data"
550
+ },
551
+ {
552
+ "data": {
553
+ "text/plain": [
554
+ "Downloading data: 0%| | 0.00/39.5M [00:00<?, ?B/s]"
555
+ ]
556
+ },
557
+ "metadata": {},
558
+ "output_type": "display_data"
559
+ },
560
+ {
561
+ "data": {
562
+ "text/plain": [
563
+ "Downloading data: 0%| | 0.00/2.08M [00:00<?, ?B/s]"
564
+ ]
565
+ },
566
+ "metadata": {},
567
+ "output_type": "display_data"
568
+ },
569
+ {
570
+ "data": {
571
+ "text/plain": [
572
+ "Extracting data files: 0%| | 0/2 [00:00<?, ?it/s]"
573
+ ]
574
+ },
575
+ "metadata": {},
576
+ "output_type": "display_data"
577
+ },
578
+ {
579
+ "data": {
580
+ "text/plain": [
581
+ "Generating train split: 0%| | 0/84437 [00:00<?, ? examples/s]"
582
+ ]
583
+ },
584
+ "metadata": {},
585
+ "output_type": "display_data"
586
+ },
587
+ {
588
+ "data": {
589
+ "text/plain": [
590
+ "Generating validation split: 0%| | 0/4401 [00:00<?, ? examples/s]"
591
+ ]
592
+ },
593
+ "metadata": {},
594
+ "output_type": "display_data"
595
+ }
596
+ ],
597
+ "source": [
598
+ "from datasets import load_dataset\n",
599
+ "\n",
600
+ "#dataset_name = \"timdettmers/openassistant-guanaco\"\n",
601
+ "dataset_name = \"OpenAssistant/oasst1\"\n",
602
+ "full_dataset = load_dataset(dataset_name, split=\"train\")\n"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": 10,
608
+ "metadata": {
609
+ "execution": {
610
+ "iopub.execute_input": "2023-12-21T18:27:25.393857Z",
611
+ "iopub.status.busy": "2023-12-21T18:27:25.393527Z",
612
+ "iopub.status.idle": "2023-12-21T18:27:53.555516Z",
613
+ "shell.execute_reply": "2023-12-21T18:27:53.554597Z",
614
+ "shell.execute_reply.started": "2023-12-21T18:27:25.393829Z"
615
+ }
616
+ },
617
+ "outputs": [
618
+ {
619
+ "name": "stdout",
620
+ "output_type": "stream",
621
+ "text": [
622
+ "{'text': '### Human: Escribir un cuento para dormir dirigido a niños de entre 5 y 8 años### Assistant: Había una vez un pequeño ratón llamado Miguel que vivía en un agujero en una pared de piedra. Miguel era un ratón muy curioso, y le encantaba explorar su entorno en busca de aventuras emocionantes.\\n\\nUn día, mientras Miguel exploraba los alrededores, encontró un pequeño agujero en la pared que no había visto antes. Decidió investigar y se arrastró por el agujero hasta llegar a un cuarto muy oscuro y tenebroso.\\n\\nMiguel estaba muy asustado, pero también muy curioso, así que decidió explorar un poco más. Pronto, encontró una vieja caja de madera que parecía estar abandonada. Con mucho cuidado, Miguel abrió la caja, y dentro encontró una hermosa varita mágica.\\n\\nMiguel nunca había visto nada igual antes, y estaba emocionado por su descubrimiento. Decidió llevar la varita mágica de vuelta a su agujero en la pared y esconderla allí.\\n\\nLa noche siguiente, Miguel no podía dormir pensando en la varita mágica. Finalmente, se acercó a la varita y la agarró con sus pequeñas patitas. Para su sorpresa, la varita empezó a brillar con una luz mágica y una hada apareció ante él.\\n\\nEl hada le dijo a Miguel que era un ratón muy especial por haber encontrado la varita mágica, y que le concedería un deseo. Miguel pensó en todas las cosas que quería, pero finalmente decidió que lo que más quería era tener amigos.\\n\\nEl hada sonrió y movió la varita mágica. De repente, un grupo de ratones de todas las formas y tamaños aparecieron ante Miguel. Todos eran amables y divertidos, y Miguel estaba muy feliz de haber hecho nuevos amigos.\\n\\nDesde entonces, Miguel y sus nuevos amigos se reunían todas las noches para jugar y contar historias. Y así, Miguel aprendió que la amistad y la aventura pueden encontrarse en los lugares más inesperados. Y con una sonrisa en el rostro, Miguel cerró los ojos y se quedó dormido, soñando con las próximas aventuras que tendría con sus nuevos amigos.### Human: Me ha encantado tu cuento ¿Podrías reescribirme el cuento pero escribiendo una historia que no contenga magia?### Assistant: ¡Me alegro que te haya encantado! Aquí va otro cuento ¡Espero que te agrade!\\n\\nHabía una vez un pequeño pueblo en el que vivían muchos animales. Había conejos, zorros, pájaros y muchos otros más. Un día, un pequeño conejito llamado Ben decidió explorar el bosque cercano a su casa.\\n\\nMientras caminaba, vio a un pequeño pajarito que estaba muy triste. Ben se acercó y le preguntó qué le sucedía. El pajarito le explicó que había perdido su nido y que no sabía dónde estaba su familia.\\n\\nBen, sintiendo mucha empatía por el pajarito, decidió ayudarlo a encontrar su nido. Juntos, buscaron en todos los árboles y arbustos del bosque. Después de mucho buscar, encontraron el nido del pajarito. Estaba en un árbol muy alto y parecía que había sido arrastrado por el viento.\\n\\nBen y el pajarito trabajaron juntos para llevar el nido de regreso al árbol y asegurarlo allí. Cuando terminaron, el pajarito estaba muy agradecido y le dijo a Ben que lo consideraba un amigo para siempre.\\n\\nDesde ese día, Ben y el pajarito se convirtieron en amigos inseparables. Juntos exploraron el bosque, hicieron nuevos amigos y vivieron muchas aventuras emocionantes.\\n\\nFinalmente, llegó el momento de dormir y Ben y el pajarito regresaron a sus hogares. Mientras se acurrucaba en su cama, Ben se sintió muy feliz de haber encontrado un amigo tan especial. Y mientras cerraba los ojos, se dio cuenta de que la amistad era la magia más poderosa de todas.### Human: Te he pedido que me reescribas el mismo cuento sin magia, no que inventaras un cuento nuevo.### Assistant: Había una vez un pequeño ratón llamado Miguel que vivía en un agujero en una pared de piedra. Miguel era un ratón muy curioso, y le encantaba explorar su entorno en busca de aventuras emocionantes.\\n\\nUn día, mientras Miguel exploraba los alrededores, encontró un pequeño agujero en la pared que no había visto antes. Decidió investigar y se arrastró por el agujero hasta llegar a un cuarto muy oscuro y tenebroso.\\n\\nMiguel estaba muy asustado, pero también muy curioso, así que decidió explorar un poco más. Pronto, encontró una vieja caja de madera que parecía estar abandonada. Con mucho cuidado, Miguel abrió la caja, y dentro encontró un objeto muy extraño para él (no para nosotros los humanos).\\n\\nMiguel nunca había visto nada igual antes, y estaba emocionado por su descubrimiento. Decidió llevar el objeto de vuelta a su agujero en la pared y esconderlo allí.\\n\\nLa noche siguiente, Miguel no podía dormir pensando en su descubrimiento. Finalmente, se acercó al objeto y lo agarró con sus pequeñas patitas. Para su sorpresa, el objeto empezó a brillar con una luz y una idea apareció ante él (es una lámpara portátil, pero, no se lo digan).\\n\\nLa idea le decía a Miguel que era un ratón muy especial por haber encontrado algo así, y que le proponía una actividad. Miguel pensó en todas las cosas que podría hacer, pero finalmente decidió que lo que más quería era conocer a otros ratones como él, y con ello.\\n\\nSu interior sonrió y Miguel empezó a reflexionar en cómo podría hacer para conocer a otros ratones. Decidió que podría invitar a los ratones de los alrededores a su agujero en la pared, atrayéndolos con el objeto extraño para jugar y contar historias.\\n\\nDesde entonces, Miguel encendía la lámpara para llamar a sus nuevos amigos, se reunían todas las noches para jugar y contar historias. Y así, Miguel aprendió que la amistad y la aventura pueden encontrarse en los lugares más inesperados, y que los artilugios traen las soluciones más inesperadas. Y con una sonrisa en el rostro, Miguel cerró los ojos y se quedó dormido, soñando con las próximas aventuras que tendría con sus nuevos amigos.\\n\\nEsperemos que la batería de la lámpara portátil nunca se acabe, pero, lo bueno es que ya todos se conocen, y nunca más necesitara usar la lámpara.\\n\\nEllos fueron amigos por siempre...\\n\\nEspero que haya sido de ayuda, si tienes alguna otra pregunta o necesitas algo más, no dudes en consultar.### Assistant: Había una vez un pequeño zorro llamado Max, que vivía en un bosque mágico rodeado de árboles enormes y un río cristalino. Max era un zorro muy curioso y le encantaba explorar su hogar en busca de aventuras. Una noche, mientras se preparaba para dormir, Max escuchó un ruido extraño que venía del bosque. Curioso como era, decidió investigar.\\n\\nAl salir de su madriguera, Max se dio cuenta de que el ruido venía de un árbol cercano. Al acercarse, encontró a un pequeño búho atrapado entre las ramas. Max se acercó y le preguntó al búho qué le había pasado. El búho explicó que había estado cazando ratones y que accidentalmente se había enredado en las ramas del árbol.\\n\\nMax no dudó en ayudar al búho. Con su agilidad y astucia, logró desenredar al pequeño búho y lo llevó de regreso a su nido en lo alto del árbol. El búho estaba muy agradecido y le prometió a Max que si alguna vez necesitaba ayuda, estaría allí para él.\\n\\nA partir de ese día, Max y el búho se convirtieron en amigos inseparables. Max aprendió mucho del sabio búho y juntos exploraron el bosque en busca de nuevas aventuras. Ya no se sentía solo en su madriguera, sabía que tenía un amigo en quien podía confiar.\\n\\nLa noche que Max había salvado al búho, mientras regresaban a casa, un espectáculo de luces apareció en el cielo, parecía una lluvia de estrellas. El búho le contó a Max que esto solo sucedía una vez al año, y que significaba que había magia en el aire.\\n\\nMax se dio cuenta de que la magia ya estaba en su vida, porque había encontrado a un amigo verdadero. Y así, acurrucados en su madriguera, Max y el búho se quedaron dormidos bajo la lluvia de estrellas, con la seguridad de que su amistad los llevaría a muchas más aventuras y descubrimientos en el bosque mágico.### Human: Escribe el mismo cuento, pero añadiendo diálogos### Assistant: Había una vez un pequeño zorro llamado Max, que vivía en un bosque mágico rodeado de árboles enormes y un río cristalino. Max era un zorro muy curioso y le encantaba explorar su hogar en busca de aventuras. Una noche, mientras se preparaba para dormir, Max escuchó un ruido extraño que venía del bosque. Curioso como era, decidió investigar.\\n\\n\"¿Qué será ese ruido?\" - se preguntó Max en voz alta.\\n\\n\"¿Qué pasa, Max?\" - preguntó su amigo búho, quien había estado despierto.\\n\\n\"Escuché un ruido extraño y fui a investigar\" - respondió Max.\\n\\n\"¿Quieres que te acompañe?\" - ofreció el búho.\\n\\nMax aceptó, y juntos se dirigieron hacia el árbol donde provenía el ruido. Al acercarse, encontraron al pequeño búho atrapado entre las ramas.\\n\\n\"¿Qué ha pasado aquí?\" - preguntó el búho al ver al pequeño atrapado.\\n\\n\"He estado cazando ratones y me enredé en las ramas del árbol\" - explicó el pequeño búho.\\n\\n\"¡No te preocupes! Vamos a ayudarte a salir de ahí\" - dijo Max decidido.\\n\\nCon su agilidad y astucia, Max logró desenredar al pequeño búho y lo llevó de regreso a su nido en lo alto del árbol.\\n\\n\"¡Gracias por ayudarme! No sé cómo podría haber salido de ahí sin ti\" - agradeció el pequeño búho.\\n\\n\"¡No hay problema! Eso es lo que hacen los amigos\" - respondió Max con una sonrisa.\\n\\nA partir de ese día, Max y el búho se convirtieron en amigos inseparables. Juntos exploraron el bosque en busca de nuevas aventuras y descubrieron cosas fascinantes sobre su hogar.\\n\\n\"¡Mira! Hay una cueva escondida detrás de ese árbol\" - exclamó Max emocionado.\\n\\n\"¡Vamos a investigarla!\" - animó el búho.\\n\\nYa no se sentía solo en su madriguera, sabía que tenía un amigo en quien podía confiar. Y así, acurrucados en su madriguera, Max y el búho se quedaron dormidos bajo la lluvia de estrellas, con la seguridad de que su amistad los llevaría a muchas más aventuras y descubrimientos en el bosque mágico.'}\n"
623
+ ]
624
+ }
625
+ ],
626
+ "source": [
627
+ "from datasets import Dataset\n",
628
+ "count = -1\n",
629
+ "conversations = []\n",
630
+ "for item in full_dataset:\n",
631
+ " parent_id = item['parent_id']\n",
632
+ " text = item['text']\n",
633
+ " role = item['role']\n",
634
+ " \n",
635
+ " if parent_id is None:\n",
636
+ " conversations.append('')\n",
637
+ " count += 1\n",
638
+ " \n",
639
+ " \n",
640
+ " if role == 'prompter':\n",
641
+ " conversations[count] += '### Human: %s' % text\n",
642
+ " elif role == 'assistant':\n",
643
+ " conversations[count] += '### Assistant: %s' % text\n",
644
+ " \n",
645
+ "#convo_list = [c for c in conversations]\n",
646
+ "dataset = Dataset.from_dict(dict(text=conversations))\n",
647
+ "print(dataset[345])"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": 11,
653
+ "metadata": {
654
+ "execution": {
655
+ "iopub.execute_input": "2023-12-21T18:28:24.234945Z",
656
+ "iopub.status.busy": "2023-12-21T18:28:24.234242Z",
657
+ "iopub.status.idle": "2023-12-21T18:28:24.260143Z",
658
+ "shell.execute_reply": "2023-12-21T18:28:24.259339Z",
659
+ "shell.execute_reply.started": "2023-12-21T18:28:24.234909Z"
660
+ }
661
+ },
662
+ "outputs": [],
663
+ "source": [
664
+ "from transformers import TrainingArguments\n",
665
+ "\n",
666
+ "output_dir = \"./results\"\n",
667
+ "per_device_train_batch_size = 2\n",
668
+ "gradient_accumulation_steps = 8\n",
669
+ "optim = \"paged_adamw_32bit\"\n",
670
+ "save_steps = 100\n",
671
+ "logging_steps = 10\n",
672
+ "learning_rate = 2e-4\n",
673
+ "max_grad_norm = 0.3\n",
674
+ "max_steps = 500\n",
675
+ "warmup_ratio = 0.03\n",
676
+ "lr_scheduler_type = \"constant\"\n",
677
+ "\n",
678
+ "training_arguments = TrainingArguments(\n",
679
+ " output_dir=output_dir,\n",
680
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
681
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
682
+ " optim=optim,\n",
683
+ " save_steps=save_steps,\n",
684
+ " logging_steps=logging_steps,\n",
685
+ " learning_rate=learning_rate,\n",
686
+ " fp16=True,\n",
687
+ " max_grad_norm=max_grad_norm,\n",
688
+ " max_steps=max_steps,\n",
689
+ " warmup_ratio=warmup_ratio,\n",
690
+ " group_by_length=True,\n",
691
+ " lr_scheduler_type=lr_scheduler_type\n",
692
+ ")"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": 12,
698
+ "metadata": {
699
+ "execution": {
700
+ "iopub.execute_input": "2023-12-21T18:28:28.994021Z",
701
+ "iopub.status.busy": "2023-12-21T18:28:28.993650Z",
702
+ "iopub.status.idle": "2023-12-21T18:28:59.848088Z",
703
+ "shell.execute_reply": "2023-12-21T18:28:59.847330Z",
704
+ "shell.execute_reply.started": "2023-12-21T18:28:28.993993Z"
705
+ }
706
+ },
707
+ "outputs": [
708
+ {
709
+ "name": "stderr",
710
+ "output_type": "stream",
711
+ "text": [
712
+ "/opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
713
+ " warnings.warn(\n"
714
+ ]
715
+ },
716
+ {
717
+ "data": {
718
+ "text/plain": [
719
+ "Map: 0%| | 0/9846 [00:00<?, ? examples/s]"
720
+ ]
721
+ },
722
+ "metadata": {},
723
+ "output_type": "display_data"
724
+ }
725
+ ],
726
+ "source": [
727
+ "from trl import SFTTrainer\n",
728
+ "\n",
729
+ "max_seq_length = 256\n",
730
+ "\n",
731
+ "trainer = SFTTrainer(\n",
732
+ " model=model,\n",
733
+ " train_dataset=dataset,\n",
734
+ " peft_config=peft_config,\n",
735
+ " dataset_text_field=\"text\",\n",
736
+ " max_seq_length=max_seq_length,\n",
737
+ " tokenizer=tokenizer,\n",
738
+ " args=training_arguments,\n",
739
+ ")"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "code",
744
+ "execution_count": 13,
745
+ "metadata": {
746
+ "execution": {
747
+ "iopub.execute_input": "2023-12-21T18:29:19.034538Z",
748
+ "iopub.status.busy": "2023-12-21T18:29:19.034143Z",
749
+ "iopub.status.idle": "2023-12-21T18:29:19.045041Z",
750
+ "shell.execute_reply": "2023-12-21T18:29:19.043887Z",
751
+ "shell.execute_reply.started": "2023-12-21T18:29:19.034509Z"
752
+ }
753
+ },
754
+ "outputs": [],
755
+ "source": [
756
+ "for name, module in trainer.model.named_modules():\n",
757
+ " if \"norm\" in name:\n",
758
+ " module = module.to(torch.float32)"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "execution_count": 14,
764
+ "metadata": {
765
+ "execution": {
766
+ "iopub.execute_input": "2023-12-21T18:29:28.833966Z",
767
+ "iopub.status.busy": "2023-12-21T18:29:28.833572Z",
768
+ "iopub.status.idle": "2023-12-21T19:54:14.245160Z",
769
+ "shell.execute_reply": "2023-12-21T19:54:14.244185Z",
770
+ "shell.execute_reply.started": "2023-12-21T18:29:28.833937Z"
771
+ }
772
+ },
773
+ "outputs": [
774
+ {
775
+ "name": "stderr",
776
+ "output_type": "stream",
777
+ "text": [
778
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n",
779
+ "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
780
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:"
781
+ ]
782
+ },
783
+ {
784
+ "name": "stdin",
785
+ "output_type": "stream",
786
+ "text": [
787
+ " ········································\n"
788
+ ]
789
+ },
790
+ {
791
+ "name": "stderr",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
795
+ ]
796
+ },
797
+ {
798
+ "data": {
799
+ "text/html": [
800
+ "Tracking run with wandb version 0.16.1"
801
+ ],
802
+ "text/plain": [
803
+ "<IPython.core.display.HTML object>"
804
+ ]
805
+ },
806
+ "metadata": {},
807
+ "output_type": "execute_result"
808
+ },
809
+ {
810
+ "data": {
811
+ "text/html": [
812
+ "Run data is saved locally in <code>/kaggle/working/wandb/run-20231221_183002-p8sfo0k2</code>"
813
+ ],
814
+ "text/plain": [
815
+ "<IPython.core.display.HTML object>"
816
+ ]
817
+ },
818
+ "metadata": {},
819
+ "output_type": "execute_result"
820
+ },
821
+ {
822
+ "data": {
823
+ "text/html": [
824
+ "Syncing run <strong><a href='https://wandb.ai/erav1/huggingface/runs/p8sfo0k2' target=\"_blank\">deft-energy-4</a></strong> to <a href='https://wandb.ai/erav1/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
825
+ ],
826
+ "text/plain": [
827
+ "<IPython.core.display.HTML object>"
828
+ ]
829
+ },
830
+ "metadata": {},
831
+ "output_type": "execute_result"
832
+ },
833
+ {
834
+ "data": {
835
+ "text/html": [
836
+ " View project at <a href='https://wandb.ai/erav1/huggingface' target=\"_blank\">https://wandb.ai/erav1/huggingface</a>"
837
+ ],
838
+ "text/plain": [
839
+ "<IPython.core.display.HTML object>"
840
+ ]
841
+ },
842
+ "metadata": {},
843
+ "output_type": "execute_result"
844
+ },
845
+ {
846
+ "data": {
847
+ "text/html": [
848
+ " View run at <a href='https://wandb.ai/erav1/huggingface/runs/p8sfo0k2' target=\"_blank\">https://wandb.ai/erav1/huggingface/runs/p8sfo0k2</a>"
849
+ ],
850
+ "text/plain": [
851
+ "<IPython.core.display.HTML object>"
852
+ ]
853
+ },
854
+ "metadata": {},
855
+ "output_type": "execute_result"
856
+ },
857
+ {
858
+ "name": "stderr",
859
+ "output_type": "stream",
860
+ "text": [
861
+ "You're using a CodeGenTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
862
+ ]
863
+ },
864
+ {
865
+ "data": {
866
+ "text/html": [
867
+ "\n",
868
+ " <div>\n",
869
+ " \n",
870
+ " <progress value='500' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
871
+ " [500/500 1:23:28, Epoch 0/1]\n",
872
+ " </div>\n",
873
+ " <table border=\"1\" class=\"dataframe\">\n",
874
+ " <thead>\n",
875
+ " <tr style=\"text-align: left;\">\n",
876
+ " <th>Step</th>\n",
877
+ " <th>Training Loss</th>\n",
878
+ " </tr>\n",
879
+ " </thead>\n",
880
+ " <tbody>\n",
881
+ " <tr>\n",
882
+ " <td>10</td>\n",
883
+ " <td>1.869900</td>\n",
884
+ " </tr>\n",
885
+ " <tr>\n",
886
+ " <td>20</td>\n",
887
+ " <td>1.790100</td>\n",
888
+ " </tr>\n",
889
+ " <tr>\n",
890
+ " <td>30</td>\n",
891
+ " <td>1.739300</td>\n",
892
+ " </tr>\n",
893
+ " <tr>\n",
894
+ " <td>40</td>\n",
895
+ " <td>1.684800</td>\n",
896
+ " </tr>\n",
897
+ " <tr>\n",
898
+ " <td>50</td>\n",
899
+ " <td>1.801400</td>\n",
900
+ " </tr>\n",
901
+ " <tr>\n",
902
+ " <td>60</td>\n",
903
+ " <td>1.748200</td>\n",
904
+ " </tr>\n",
905
+ " <tr>\n",
906
+ " <td>70</td>\n",
907
+ " <td>1.736700</td>\n",
908
+ " </tr>\n",
909
+ " <tr>\n",
910
+ " <td>80</td>\n",
911
+ " <td>1.653100</td>\n",
912
+ " </tr>\n",
913
+ " <tr>\n",
914
+ " <td>90</td>\n",
915
+ " <td>1.719100</td>\n",
916
+ " </tr>\n",
917
+ " <tr>\n",
918
+ " <td>100</td>\n",
919
+ " <td>1.796200</td>\n",
920
+ " </tr>\n",
921
+ " <tr>\n",
922
+ " <td>110</td>\n",
923
+ " <td>1.727100</td>\n",
924
+ " </tr>\n",
925
+ " <tr>\n",
926
+ " <td>120</td>\n",
927
+ " <td>1.650200</td>\n",
928
+ " </tr>\n",
929
+ " <tr>\n",
930
+ " <td>130</td>\n",
931
+ " <td>1.760600</td>\n",
932
+ " </tr>\n",
933
+ " <tr>\n",
934
+ " <td>140</td>\n",
935
+ " <td>1.715600</td>\n",
936
+ " </tr>\n",
937
+ " <tr>\n",
938
+ " <td>150</td>\n",
939
+ " <td>1.782500</td>\n",
940
+ " </tr>\n",
941
+ " <tr>\n",
942
+ " <td>160</td>\n",
943
+ " <td>1.751000</td>\n",
944
+ " </tr>\n",
945
+ " <tr>\n",
946
+ " <td>170</td>\n",
947
+ " <td>1.678500</td>\n",
948
+ " </tr>\n",
949
+ " <tr>\n",
950
+ " <td>180</td>\n",
951
+ " <td>1.643800</td>\n",
952
+ " </tr>\n",
953
+ " <tr>\n",
954
+ " <td>190</td>\n",
955
+ " <td>1.758200</td>\n",
956
+ " </tr>\n",
957
+ " <tr>\n",
958
+ " <td>200</td>\n",
959
+ " <td>1.767500</td>\n",
960
+ " </tr>\n",
961
+ " <tr>\n",
962
+ " <td>210</td>\n",
963
+ " <td>1.735000</td>\n",
964
+ " </tr>\n",
965
+ " <tr>\n",
966
+ " <td>220</td>\n",
967
+ " <td>1.635400</td>\n",
968
+ " </tr>\n",
969
+ " <tr>\n",
970
+ " <td>230</td>\n",
971
+ " <td>1.666700</td>\n",
972
+ " </tr>\n",
973
+ " <tr>\n",
974
+ " <td>240</td>\n",
975
+ " <td>1.690100</td>\n",
976
+ " </tr>\n",
977
+ " <tr>\n",
978
+ " <td>250</td>\n",
979
+ " <td>1.763400</td>\n",
980
+ " </tr>\n",
981
+ " <tr>\n",
982
+ " <td>260</td>\n",
983
+ " <td>1.700900</td>\n",
984
+ " </tr>\n",
985
+ " <tr>\n",
986
+ " <td>270</td>\n",
987
+ " <td>1.606000</td>\n",
988
+ " </tr>\n",
989
+ " <tr>\n",
990
+ " <td>280</td>\n",
991
+ " <td>1.679100</td>\n",
992
+ " </tr>\n",
993
+ " <tr>\n",
994
+ " <td>290</td>\n",
995
+ " <td>1.631000</td>\n",
996
+ " </tr>\n",
997
+ " <tr>\n",
998
+ " <td>300</td>\n",
999
+ " <td>1.884500</td>\n",
1000
+ " </tr>\n",
1001
+ " <tr>\n",
1002
+ " <td>310</td>\n",
1003
+ " <td>1.691000</td>\n",
1004
+ " </tr>\n",
1005
+ " <tr>\n",
1006
+ " <td>320</td>\n",
1007
+ " <td>1.705400</td>\n",
1008
+ " </tr>\n",
1009
+ " <tr>\n",
1010
+ " <td>330</td>\n",
1011
+ " <td>1.643000</td>\n",
1012
+ " </tr>\n",
1013
+ " <tr>\n",
1014
+ " <td>340</td>\n",
1015
+ " <td>1.738400</td>\n",
1016
+ " </tr>\n",
1017
+ " <tr>\n",
1018
+ " <td>350</td>\n",
1019
+ " <td>1.676200</td>\n",
1020
+ " </tr>\n",
1021
+ " <tr>\n",
1022
+ " <td>360</td>\n",
1023
+ " <td>1.674900</td>\n",
1024
+ " </tr>\n",
1025
+ " <tr>\n",
1026
+ " <td>370</td>\n",
1027
+ " <td>1.719000</td>\n",
1028
+ " </tr>\n",
1029
+ " <tr>\n",
1030
+ " <td>380</td>\n",
1031
+ " <td>1.614800</td>\n",
1032
+ " </tr>\n",
1033
+ " <tr>\n",
1034
+ " <td>390</td>\n",
1035
+ " <td>1.648600</td>\n",
1036
+ " </tr>\n",
1037
+ " <tr>\n",
1038
+ " <td>400</td>\n",
1039
+ " <td>1.857600</td>\n",
1040
+ " </tr>\n",
1041
+ " <tr>\n",
1042
+ " <td>410</td>\n",
1043
+ " <td>1.743500</td>\n",
1044
+ " </tr>\n",
1045
+ " <tr>\n",
1046
+ " <td>420</td>\n",
1047
+ " <td>1.696600</td>\n",
1048
+ " </tr>\n",
1049
+ " <tr>\n",
1050
+ " <td>430</td>\n",
1051
+ " <td>1.621600</td>\n",
1052
+ " </tr>\n",
1053
+ " <tr>\n",
1054
+ " <td>440</td>\n",
1055
+ " <td>1.643400</td>\n",
1056
+ " </tr>\n",
1057
+ " <tr>\n",
1058
+ " <td>450</td>\n",
1059
+ " <td>1.760000</td>\n",
1060
+ " </tr>\n",
1061
+ " <tr>\n",
1062
+ " <td>460</td>\n",
1063
+ " <td>1.601600</td>\n",
1064
+ " </tr>\n",
1065
+ " <tr>\n",
1066
+ " <td>470</td>\n",
1067
+ " <td>1.605700</td>\n",
1068
+ " </tr>\n",
1069
+ " <tr>\n",
1070
+ " <td>480</td>\n",
1071
+ " <td>1.664800</td>\n",
1072
+ " </tr>\n",
1073
+ " <tr>\n",
1074
+ " <td>490</td>\n",
1075
+ " <td>1.646200</td>\n",
1076
+ " </tr>\n",
1077
+ " <tr>\n",
1078
+ " <td>500</td>\n",
1079
+ " <td>1.764400</td>\n",
1080
+ " </tr>\n",
1081
+ " </tbody>\n",
1082
+ "</table><p>"
1083
+ ],
1084
+ "text/plain": [
1085
+ "<IPython.core.display.HTML object>"
1086
+ ]
1087
+ },
1088
+ "metadata": {},
1089
+ "output_type": "execute_result"
1090
+ },
1091
+ {
1092
+ "data": {
1093
+ "text/plain": [
1094
+ "TrainOutput(global_step=500, training_loss=1.7096557769775391, metrics={'train_runtime': 5081.3153, 'train_samples_per_second': 1.574, 'train_steps_per_second': 0.098, 'total_flos': 3.293667738832896e+16, 'train_loss': 1.7096557769775391, 'epoch': 0.81})"
1095
+ ]
1096
+ },
1097
+ "execution_count": 14,
1098
+ "metadata": {},
1099
+ "output_type": "execute_result"
1100
+ }
1101
+ ],
1102
+ "source": [
1103
+ "trainer.train()"
1104
+ ]
1105
+ },
1106
+ {
1107
+ "cell_type": "code",
1108
+ "execution_count": 16,
1109
+ "metadata": {
1110
+ "execution": {
1111
+ "iopub.execute_input": "2023-12-21T19:54:58.501371Z",
1112
+ "iopub.status.busy": "2023-12-21T19:54:58.501013Z",
1113
+ "iopub.status.idle": "2023-12-21T19:55:00.218026Z",
1114
+ "shell.execute_reply": "2023-12-21T19:55:00.216901Z",
1115
+ "shell.execute_reply.started": "2023-12-21T19:54:58.501343Z"
1116
+ }
1117
+ },
1118
+ "outputs": [],
1119
+ "source": [
1120
+ "from transformers import (\n",
1121
+ " AutoModelForCausalLM,\n",
1122
+ " AutoTokenizer,\n",
1123
+ " BitsAndBytesConfig,\n",
1124
+ " HfArgumentParser,\n",
1125
+ " TrainingArguments,\n",
1126
+ " pipeline,\n",
1127
+ " logging,\n",
1128
+ ")"
1129
+ ]
1130
+ },
1131
+ {
1132
+ "cell_type": "code",
1133
+ "execution_count": 28,
1134
+ "metadata": {
1135
+ "execution": {
1136
+ "iopub.execute_input": "2023-12-21T20:08:21.028778Z",
1137
+ "iopub.status.busy": "2023-12-21T20:08:21.027885Z",
1138
+ "iopub.status.idle": "2023-12-21T20:08:36.121203Z",
1139
+ "shell.execute_reply": "2023-12-21T20:08:36.120260Z",
1140
+ "shell.execute_reply.started": "2023-12-21T20:08:21.028747Z"
1141
+ }
1142
+ },
1143
+ "outputs": [
1144
+ {
1145
+ "name": "stdout",
1146
+ "output_type": "stream",
1147
+ "text": [
1148
+ "### Human: Explain blockchain to a five year old in two sentences.\n",
1149
+ "\n",
1150
+ "### Assistant: \n",
1151
+ "Blockchain is like a big book where everyone can write down what they do, and then everyone can see it. It's like a big game of telephone, but instead of just one person telling the story, everyone can see it all at once.### Assistant: Blockchain is like a big book where everyone can write down what they do, and then everyone can see it. It's like a big game of telephone, but instead of just one person telling the story, everyone can see it all at once.### Assistant: Blockchain is like a big book where everyone can write down what they do, and then everyone can see it. It's like a big game of telephone, but instead of just one person telling the story, everyone can see it all at once.### Assistant: Blockchain is like a big book where everyone can write down what they do, and then everyone can see it.\n"
1152
+ ]
1153
+ }
1154
+ ],
1155
+ "source": [
1156
+ "# Run text generation pipeline with our next model\n",
1157
+ "prompt = \"Explain blockchain to a five year old in two sentences.\"\n",
1158
+ "pipe = pipeline(task=\"text-generation\", model=model, tokenizer=tokenizer, max_length=200)\n",
1159
+ "result = pipe(f\"### Human: {prompt}\\n\\n### Assistant: \")\n",
1160
+ "print(result[0]['generated_text'])"
1161
+ ]
1162
+ },
1163
+ {
1164
+ "cell_type": "code",
1165
+ "execution_count": 42,
1166
+ "metadata": {
1167
+ "execution": {
1168
+ "iopub.execute_input": "2023-12-21T20:47:38.133788Z",
1169
+ "iopub.status.busy": "2023-12-21T20:47:38.132720Z",
1170
+ "iopub.status.idle": "2023-12-21T20:47:40.133261Z",
1171
+ "shell.execute_reply": "2023-12-21T20:47:40.132074Z",
1172
+ "shell.execute_reply.started": "2023-12-21T20:47:38.133744Z"
1173
+ }
1174
+ },
1175
+ "outputs": [
1176
+ {
1177
+ "name": "stderr",
1178
+ "output_type": "stream",
1179
+ "text": [
1180
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1181
+ "To disable this warning, you can either:\n",
1182
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1183
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
1184
+ ]
1185
+ },
1186
+ {
1187
+ "name": "stdout",
1188
+ "output_type": "stream",
1189
+ "text": [
1190
+ " adding: kaggle/working/results/runs/ (stored 0%)\n",
1191
+ " adding: kaggle/working/results/runs/Dec21_18-28-24_6ea8eab21180/ (stored 0%)\n",
1192
+ " adding: kaggle/working/results/runs/Dec21_18-28-24_6ea8eab21180/events.out.tfevents.1703183372.6ea8eab21180.42.0 (deflated 65%)\n"
1193
+ ]
1194
+ },
1195
+ {
1196
+ "name": "stderr",
1197
+ "output_type": "stream",
1198
+ "text": [
1199
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1200
+ "To disable this warning, you can either:\n",
1201
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1202
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
1203
+ ]
1204
+ },
1205
+ {
1206
+ "name": "stdout",
1207
+ "output_type": "stream",
1208
+ "text": [
1209
+ "checkpoint-100\tcheckpoint-300\tcheckpoint-500\n",
1210
+ "checkpoint-200\tcheckpoint-400\truns\n"
1211
+ ]
1212
+ }
1213
+ ],
1214
+ "source": [
1215
+ "!zip -r runs.zip /kaggle/working/results/runs\n",
1216
+ "!ls results"
1217
+ ]
1218
+ },
1219
+ {
1220
+ "cell_type": "code",
1221
+ "execution_count": null,
1222
+ "metadata": {},
1223
+ "outputs": [],
1224
+ "source": [
1225
+ "import torch\n",
1226
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
1227
+ "#model_path = \"/piyushgrover/phi-2-qlora-merged-custom\" # change to your preferred path\n",
1228
+ "model_name = \"microsoft/phi-2\"\n",
1229
+ "# device_map = {\"\": 1}\n",
1230
+ "\n",
1231
+ "# Reload model in FP16 and merge it with LoRA weights\n",
1232
+ "base_model = AutoModelForCausalLM.from_pretrained(\n",
1233
+ " model_name,\n",
1234
+ " low_cpu_mem_usage=True,\n",
1235
+ " return_dict=True,\n",
1236
+ " torch_dtype=torch.float16,\n",
1237
+ " trust_remote_code=True\n",
1238
+ " # device_map=device_map,\n",
1239
+ ")"
1240
+ ]
1241
+ },
1242
+ {
1243
+ "cell_type": "code",
1244
+ "execution_count": null,
1245
+ "metadata": {},
1246
+ "outputs": [],
1247
+ "source": [
1248
+ "from peft import PeftModel\n",
1249
+ "new_model = \"/piyushgrover/phi-2-qlora-adapter-custom\"\n",
1250
+ "model = PeftModel.from_pretrained(base_model, new_model)\n",
1251
+ "model = model.merge_and_unload()\n",
1252
+ "\n",
1253
+ "# Reload tokenizer to save it\n",
1254
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
1255
+ "tokenizer.pad_token = tokenizer.eos_token\n",
1256
+ "tokenizer.padding_side = \"right\"\n",
1257
+ "\n",
1258
+ "# Save the merged model\n",
1259
+ "#model.save_pretrained(model_path)\n",
1260
+ "#tokenizer.save_pretrained(model_path)\n",
1261
+ "#from transformers import AutoModelForCausalLM, AutoTokenizer\n",
1262
+ "\n",
1263
+ "#model_path = \"/piyushgrover/phi-2-qlora-merged\" # change to your preferred path\n",
1264
+ "\n",
1265
+ "#model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)\n",
1266
+ "#tokenizer = AutoTokenizer.from_pretrained(model_path)"
1267
+ ]
1268
+ },
1269
+ {
1270
+ "cell_type": "code",
1271
+ "execution_count": null,
1272
+ "metadata": {},
1273
+ "outputs": [],
1274
+ "source": [
1275
+ "from transformers import pipeline\n",
1276
+ "\n",
1277
+ "prompt = \"What was the role of indian revolutionaries in indian independence movement ?\" # change to your desired prompt\n",
1278
+ "gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_length=500)\n",
1279
+ "result = gen(prompt)\n",
1280
+ "print(result[0]['generated_text'])"
1281
+ ]
1282
+ }
1283
+ ],
1284
+ "metadata": {
1285
+ "kaggle": {
1286
+ "accelerator": "gpu",
1287
+ "dataSources": [],
1288
+ "dockerImageVersionId": 30627.0,
1289
+ "isGpuEnabled": true,
1290
+ "isInternetEnabled": true,
1291
+ "language": "python",
1292
+ "sourceType": "notebook"
1293
+ },
1294
+ "kernelspec": {
1295
+ "display_name": "Python 3 (ipykernel)",
1296
+ "language": "python",
1297
+ "name": "python3"
1298
+ },
1299
+ "language_info": {
1300
+ "codemirror_mode": {
1301
+ "name": "ipython",
1302
+ "version": 3
1303
+ },
1304
+ "file_extension": ".py",
1305
+ "mimetype": "text/x-python",
1306
+ "name": "python",
1307
+ "nbconvert_exporter": "python",
1308
+ "pygments_lexer": "ipython3",
1309
+ "version": "3.8.5"
1310
+ }
1311
+ },
1312
+ "nbformat": 4,
1313
+ "nbformat_minor": 4
1314
+ }