PythiaChat-2.8B_v0.1

This model is a fine-tuned version of EleutherAI/pythia-2.8b-deduped, trained with LoRA on the Baize dataset for chat-style instruction following. It was trained for only 200+ steps, as a test run.

Sample Use

Start each human message with [|Human|] and each AI message with [|AI|].
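The prompt can also be built in code. The sketch below uses `build_prompt`, a hypothetical helper that is not part of any library. It takes the header line from the sample below and puts one turn on each line. The prompt ends with an open `[|AI|]` marker so the model writes the next reply. The card only shows a single-turn prompt, so the multi-turn layout is an assumption.

```python
def build_prompt(turns, header="The conversation between human and AI assistant."):
    """Build a Baize-style prompt from (speaker, text) pairs.

    Hypothetical helper: speaker must be "human" or "ai". The prompt ends
    with an open "[|AI|]" marker so the model generates the next AI reply.
    """
    markers = {"human": "[|Human|]", "ai": "[|AI|]"}
    lines = [header]
    for speaker, text in turns:
        lines.append(f"{markers[speaker]} {text}")
    lines.append("[|AI|]")  # open turn for the model to complete
    return "\n".join(lines)


print(build_prompt([("human", "How do I open a file with python?")]))
```

For a single human turn, this produces the same string as `input_text` in the sample below.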

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

peft_model_id = "linkanjarad/PythiaChat-2.8B_v0.1"
model_id = "EleutherAI/pythia-2.8b-deduped"

# Load the base model, then attach the LoRA adapter on top of it.
# device_map="auto" already places the weights on the available GPU(s).
# You can add `load_in_4bit=True` (requires bitsandbytes) to reduce memory use.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)

input_text = """The conversation between human and AI assistant.
[|Human|] How do I open a file with python?
[|AI|]"""

# Tokenize the prompt and move it to the same device as the model
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

# Sample up to 120 new tokens after the prompt
with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=120,
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,  # avoids the missing pad-token warning
    )

# Decode the prompt together with the generated continuation
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
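A sampled continuation may keep going and invent a further `[|Human|]` turn. To keep only the assistant's answer, one option is to decode just the new tokens with `tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)` and then cut the text at the next turn marker. The following is a minimal sketch. `extract_reply` is a hypothetical helper, not part of Transformers or PEFT.

```python
def extract_reply(continuation: str) -> str:
    """Trim a decoded continuation to the first assistant turn.

    `continuation` is assumed to contain only the newly generated text,
    not the echoed prompt. Everything from the first "[|Human|]" or
    "[|AI|]" marker onward is dropped, along with surrounding whitespace.
    """
    cut = len(continuation)
    for marker in ("[|Human|]", "[|AI|]"):
        idx = continuation.find(marker)
        if idx != -1:
            cut = min(cut, idx)  # keep the earliest marker position
    return continuation[:cut].strip()


print(extract_reply(" Use the built-in open() function.\n[|Human|] Thanks!"))
```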

Example Output

The conversation between human and AI assistant.
[|Human|] How do I open a file with python?
[|AI|] To open a file with python, you can use the open function as follows:

>>> with open('filename.txt', 'w') as f:
...     f.write(data)

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 7e-05
  • train_batch_size: 4
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 8
  • total_train_batch_size: 32
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 80
  • num_epochs: 1
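The batch sizes fit together as per-device batch size times gradient accumulation steps: 4 × 8 = 32 examples per optimizer step. This assumes a single device, which the card does not state. The `linear` scheduler with 80 warmup steps raises the learning rate from 0 to 7e-05 and then decays it linearly to 0. The sketch below reproduces that shape using the formula behind Transformers' `get_linear_schedule_with_warmup`. The card only says "200+ steps", so `TOTAL = 200` is an illustrative assumption.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device * grad_accum * num_devices


def linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


BASE_LR = 7e-05
WARMUP = 80
TOTAL = 200  # assumption: the card only says "200+ steps"

print(effective_batch_size(4, 8))  # 32
for step in (0, 40, 80, 140, 200):
    print(step, linear_lr(step, BASE_LR, WARMUP, TOTAL))
```

With these values, the learning rate peaks at step 80. It reaches half the peak again at step 140 and falls to 0 at the final step.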

Framework versions

  • PEFT 0.4.0
  • Transformers 4.31.0
  • Pytorch 2.0.0
  • Datasets 2.13.1
  • Tokenizers 0.13.3