---
license: apache-2.0
datasets:
- amphora/QwQ-LongCoT-130K
language:
- en
metrics:
- perplexity
base_model:
- Qwen/Qwen2.5-0.5B-Instruct
---

## Model Details:

- **Base Model:** Qwen/Qwen2-0.5B-Instruct
- **Teacher Model:** Qwen/QwQ-32B-Preview
- **Distillation Framework:** Generalized Knowledge Distillation (GKD), via TRL's `GKDTrainer`
- **Task Type:** Conversational AI / Causal Language Modeling
- **Parameters:** 0.5B
- **Special Features:**
  - Optional LoRA fine-tuning through PEFT's `LoraConfig` (enabled with `--lora`)
  - Optional gradient checkpointing to reduce training memory (enabled with `--gradient_checkpointing`)
  - Step-by-step reasoning style learned from the teacher's long chain-of-thought outputs
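
When `--lora` is passed, the training script trains low-rank adapters instead of the full weights. The sketch below is a minimal, dependency-free illustration of the standard LoRA update, assuming PEFT's default scaling of `lora_alpha / r` and ignoring `lora_dropout`; the helpers `matvec` and `lora_forward` are hypothetical and not part of PEFT. With the script's `r=16` and `lora_alpha=32`, the low-rank update is scaled by 2.0, and because `B` starts at zero the adapted layer initially reproduces the base layer exactly.

```python
# Minimal sketch of a LoRA-adapted linear layer (no dropout), showing how
# r and lora_alpha interact. Plain lists stand in for tensors; this is an
# illustration of the technique, NOT PEFT's implementation.

def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def lora_forward(x, base_weight, lora_a, lora_b, r, lora_alpha):
    """Return W0 @ x + (lora_alpha / r) * B @ (A @ x).

    base_weight: frozen d_out x d_in matrix W0
    lora_a:      trainable r x d_in matrix A
    lora_b:      trainable d_out x r matrix B
    """
    scaling = lora_alpha / r
    base_out = matvec(base_weight, x)
    update = matvec(lora_b, matvec(lora_a, x))
    return [b + scaling * u for b, u in zip(base_out, update)]


# Same ratio as the training script (r=16, lora_alpha=32)
print(32 / 16)  # 2.0

# Toy 2x2 layer with r=1 and lora_alpha=2 (also a 2.0 scaling)
W0 = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]           # r x d_in = 1 x 2
B_zero = [[0.0], [0.0]]    # d_out x r = 2 x 1; LoRA initializes B to zeros
B_trained = [[0.5], [0.0]]

x = [1.0, 2.0]
print(lora_forward(x, W0, A, B_zero, r=1, lora_alpha=2))     # [1.0, 2.0]
print(lora_forward(x, W0, A, B_trained, r=1, lora_alpha=2))  # [4.0, 2.0]
```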
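
---

## Training:

QwQ-0.5B-Distilled was trained on the **QwQ-LongCoT-130K dataset**, which pairs each problem (the `problem` field) with a long chain-of-thought response from QwQ (the `qwq` field), making it well suited to reasoning and conversational tasks. GKD trains the student to match the teacher's next-token distributions rather than only the reference text: the loss is a generalized Jensen-Shannon divergence between student and teacher, interpolated by `beta`, and with probability `lmbda` a training step uses responses sampled from the student itself (on-policy) instead of the dataset responses.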
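
The divergence GKD minimizes can be sketched for a single token position as follows. This is a minimal, dependency-free illustration of the generalized JSD as defined for GKD (teacher P, student Q, mixture M = βP + (1 − β)Q), using the script's default `beta=0.5` and `temperature=0.9`. TRL's `GKDTrainer` uses its own batched tensor implementation, so the helper names here are hypothetical. Both distributions must cover the same vocabulary, which is why the training script trims the teacher's `lm_head` to 151,936 rows.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(teacher_logits, student_logits, beta=0.5, temperature=0.9):
    """beta * KL(P || M) + (1 - beta) * KL(Q || M), M = beta * P + (1 - beta) * Q.

    P is the teacher distribution, Q the student distribution.
    Assumes 0 < beta < 1 and logits over one shared vocabulary.
    """
    if len(teacher_logits) != len(student_logits):
        raise ValueError("teacher and student must share one vocabulary")
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)


# Identical logits give zero divergence; different logits give a positive value
print(generalized_jsd([2.0, 1.0, 0.1], [2.0, 1.0, 0.1]))      # 0.0
print(generalized_jsd([2.0, 1.0, 0.1], [0.1, 1.0, 2.0]) > 0)  # True
```

At `beta=0.5` this reduces to the ordinary Jensen-Shannon divergence, which is symmetric and bounded above by ln 2.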

### Training Progress:

[▓░░░░░░░░░░] 10%

### Training Script:

```python
import argparse

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.9)
parser.add_argument("--lmbda", type=float, default=0.5)
parser.add_argument("--beta", type=float, default=0.5)
parser.add_argument("--max_new_tokens", type=int, default=4096)
parser.add_argument("--output_dir", type=str, default="gkd-model")
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
parser.add_argument("--resume_from_checkpoint", action="store_true", default=False)
parser.add_argument("--lora", action="store_true")
args = parser.parse_args()

# Convert each dataset row into a system/user/assistant conversation
qwq_dataset = load_dataset("amphora/QwQ-LongCoT-130K", split="train")
messages = []
for each in qwq_dataset:
    msg = [
        {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
        {"role": "user", "content": each["problem"]},
        {"role": "assistant", "content": each["qwq"]},
    ]
    messages.append(msg)

# 90/10 train/eval split by position (no shuffling)
TRAIN_SPLIT_RATIO = 0.9
train_size = int(TRAIN_SPLIT_RATIO * len(messages))
eval_size = len(messages) - train_size

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# The teacher model whose token distributions the student learns to match
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/QwQ-32B-Preview", torch_dtype=torch.bfloat16, device_map="auto")
# Trim the teacher's output layer to the student's 151,936-entry vocabulary
# so that teacher and student logits have the same shape
teacher_model.lm_head.weight.data = teacher_model.lm_head.weight.data[:151936, :]
teacher_model.lm_head.out_features = 151936

# The model to optimise (the student)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")

# Real dataset
train_dataset = Dataset.from_dict({"messages": messages[:train_size]})
eval_dataset = Dataset.from_dict({"messages": messages[train_size:]})

training_args = GKDConfig(
    output_dir=args.output_dir,
    temperature=args.temperature,
    lmbda=args.lmbda,
    beta=args.beta,
    max_new_tokens=args.max_new_tokens,
    per_device_train_batch_size=args.per_device_train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=args.gradient_checkpointing,
    save_steps=100,
    save_total_limit=5,
)

# LoRA adapter settings, used only when --lora is passed
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config if args.lora else None,
)
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
```
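
The `--lmbda` flag decides how often a training step uses responses sampled from the student instead of the dataset's QwQ responses. A minimal sketch of that choice, plus the effective batch size implied by the script's defaults, is shown here; `choose_batch_source` and `effective_batch_size` are hypothetical helpers, and TRL makes the real decision inside `GKDTrainer`'s training step.

```python
import random


def choose_batch_source(lmbda, rng):
    """Return "student" (on-policy) with probability lmbda, else "dataset".

    Hypothetical helper illustrating GKD's lmbda; not part of TRL's API.
    """
    return "student" if rng.random() < lmbda else "dataset"


def effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of sequences contributing to one optimizer step."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


rng = random.Random(0)
sources = [choose_batch_source(0.5, rng) for _ in range(10_000)]
print(sources.count("student") / len(sources))  # close to 0.5
print(effective_batch_size(1, 16))  # 16
```

With the defaults, each optimizer step accumulates gradients over 16 conversations per device, and roughly half of the micro-batches (`lmbda=0.5`) require the student to generate up to `max_new_tokens=4096` tokens before the loss can be computed, which makes those steps considerably slower.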

### Dataset:

- **Source:** `amphora/QwQ-LongCoT-130K`
- **Split:** 90% training, 10% evaluation (the first 90% of rows are used for training and the remaining 10% for evaluation, without shuffling)
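
Because the split slices the list by position, any ordering in the source rows carries over into the evaluation set. The sketch below reproduces the script's arithmetic (`int(0.9 * n)` training rows, the rest for evaluation) and adds an optional seeded shuffle; the shuffle and the `split_conversations` helper are assumptions for illustration, not part of the original script.

```python
import random


def split_conversations(messages, train_ratio=0.9, seed=None):
    """Split a list into (train, eval) lists.

    With seed=None the order is kept, matching the script's
    messages[:train_size] / messages[train_size:] slicing. Passing a seed
    shuffles a copy first (an optional addition, not in the original script).
    """
    items = list(messages)
    if seed is not None:
        random.Random(seed).shuffle(items)
    train_size = int(train_ratio * len(items))
    return items[:train_size], items[train_size:]


train, evaluation = split_conversations(list(range(10)))
print(len(train), len(evaluation))  # 9 1
print(evaluation)  # [9]
```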

---

## Example Usage:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "kz919/QwQ-0.5B-Distilled"

# Load the model onto GPU 0 in bfloat16
print(f"Starting to load the model {model_name} into memory")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the prompt, using the same system message as in training
prompt = "How many r's are in strawberry?"
messages = [
    {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
    {"role": "user", "content": prompt},
]

# Apply the chat template, then tokenize the input
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate a response
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=4096,
)
# Keep only the newly generated tokens (drop the echoed prompt)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the response
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
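
`apply_chat_template` turns the message list into a single prompt string, and `generate` returns the prompt tokens followed by the new ones, which is why the example slices them off. The sketch below approximates both steps in plain Python, assuming the model keeps Qwen2's ChatML-style template (`<|im_start|>role`, newline, content, `<|im_end|>`); the helpers `render_chatml` and `strip_prompt` are hypothetical, the authoritative template is the one stored in `tokenizer.chat_template`, and the sketch assumes a system message is always supplied, as in the example above.

```python
def render_chatml(messages, add_generation_prompt=True):
    """Approximate a Qwen2-style ChatML rendering of a conversation.

    Assumption: mirrors the ChatML layout only; it does not reproduce every
    rule of the real template (e.g. default system messages).
    """
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues as the assistant
        text += "<|im_start|>assistant\n"
    return text


def strip_prompt(input_ids, output_ids):
    """Drop the prompt tokens that generate() returns before the new tokens."""
    return [out[len(inp):] for inp, out in zip(input_ids, output_ids)]


print(render_chatml([
    {"role": "system", "content": "You should think step-by-step."},
    {"role": "user", "content": "How many r's are in strawberry?"},
]))
print(strip_prompt([[1, 2, 3]], [[1, 2, 3, 7, 8]]))  # [[7, 8]]
```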

---

## Applications:

1. **Conversational Assistants:**
   Chatbots that benefit from step-by-step reasoning and long-context understanding.

2. **Educational Tools:**
   Step-by-step explanations for learning environments.

3. **Creative Writing:**
   Drafting coherent, context-aware long-form content.

4. **Technical Support:**
   Working through customer queries with structured, step-by-step answers.

---

## Limitations:

- Although distilled from QwQ-32B-Preview, this much smaller model is likely to trail the teacher on highly complex reasoning tasks.
- Warning 🚨🚨🚨: This model is not fully trained and is merely a proof of concept. Don't yell at me if it outputs nonsense.

---

## Citation:

If you use this model in your research or applications, please cite it as:

```bibtex
@misc{qwq_0.5B_distilled,
  author    = {Kaizhao Liang},
  title     = {QwQ-0.5B-Distilled: A Reasoning Model for Edge Devices},
  year      = {2024},
  publisher = {Hugging Face},
  version   = {1.0}
}
```

---

This model illustrates how distillation and parameter-efficient fine-tuning can bring step-by-step conversational reasoning to a much smaller, more manageable footprint.