---
library_name: transformers
license: apache-2.0
datasets:
- THUDM/AgentInstruct
language:
- en
---
# Mistral-7B fine-tuned on AgentInstruct
[Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) fine-tuned on the [AgentInstruct](https://huggingface.co/datasets/THUDM/AgentInstruct) dataset to make it *better* at acting as an agent.
## Model Details
### Model Description
The base model, Mistral-7B-v0.1, is a pretrained generative text Large Language Model (LLM) with 7 billion parameters.
According to Mistral AI, Mistral-7B-v0.1 outperforms Llama 2 13B on all the benchmarks they tested.
For full details of the base model, see the Mistral AI [paper](https://arxiv.org/abs/2310.06825) and [release blog post](https://mistral.ai/news/announcing-mistral-7b/).
## Model Architecture
Mistral-7B-v0.1 is a transformer model, with the following architecture choices:
- Grouped-Query Attention
- Sliding-Window Attention
- Byte-fallback BPE tokenizer
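To make the two attention choices more concrete, here is a minimal pure-Python sketch. It is not Mistral's implementation. `sliding_window_causal_mask` builds a causal mask in which each token sees only the most recent `window` positions, itself included (this exact boundary convention is an assumption). `kv_head_for_query_head` shows how grouped-query attention lets several query heads share one key/value head. Both helper names are hypothetical. The head counts (32 query heads, 8 key/value heads) and the 4096-token window come from the base model's published configuration and are used here only for illustration.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Assumed convention: position i sees itself and the window - 1 positions before it.
    """
    return [
        [j <= i and i - j < window for j in range(seq_len)]
        for i in range(seq_len)
    ]


def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Grouped-query attention: consecutive query heads share one key/value head."""
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Toy mask: 6 tokens with a window of 3 (the base model uses a 4096-token window)
for row in sliding_window_causal_mask(6, 3):
    print("".join("x" if allowed else "." for allowed in row))

# 32 query heads sharing 8 key/value heads -> groups of 4
print([kv_head_for_query_head(h, 32, 8) for h in range(8)])
```

Sharing key/value heads shrinks the KV cache by the group factor (4 in this configuration), while the window bounds how many past positions each layer reads directly.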
## Dataset Details
**AgentInstruct** is a curated dataset of **1,866** high-quality interaction trajectories covering six real-world agent tasks. It was built using methods such as **Task Derivation** and **Self-Instruct**.
- πŸ” **CoT** - Harness the power of [ReAct](https://react-lm.github.io/), offering detailed thought explanations for each action, ensuring an intricate understanding of the model's decision-making journey.
- 🌍 **Diversity** - Spanning 6 real-world scenarios, from Daily Household Routines to Database Operations, and their average turns range from 5 to 35.
- 🎯 **Precision** - Not all trajectories of GPT-4 are effective! Ours are rigorously filtered using strict rewards to ensure top-notch quality.
- βœ… **Assurance** - Rigorous checks to avoid data leakage, ensuring pristine dataset quality.
## Task Overview
| Task | # Filt. Traj. | Avg # Filt. Traj. Turns |
|---|---|---|
|ALFWorld|336|13.52|
|WebShop|351|3.68|
|Mind2Web|122|1.00|
|Knowledge Graph|324|6.04|
|Operating System|195|3.85|
|Database|538|2.06|
|**AgentInstruct**|1866|5.24|
AgentInstruct includes 1,866 trajectories from 6 agent tasks. "Traj." stands for interaction trajectory and "Filt. Traj." for filtered trajectories. The last row gives the total number of filtered trajectories and the overall average number of turns.
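As a quick consistency check, the **AgentInstruct** row agrees with the per-task rows: its count is their sum, and its average is the trajectory-weighted mean of the per-task averages. The sketch below recomputes both from the table values. Those values are already rounded, so the match is only up to rounding.

```python
# (filtered trajectories, average turns per filtered trajectory), copied from the table
tasks = {
    "ALFWorld": (336, 13.52),
    "WebShop": (351, 3.68),
    "Mind2Web": (122, 1.00),
    "Knowledge Graph": (324, 6.04),
    "Operating System": (195, 3.85),
    "Database": (538, 2.06),
}

# Total trajectories is a plain sum; the overall average weights each task by its count
total_trajectories = sum(n for n, _ in tasks.values())
total_turns = sum(n * avg for n, avg in tasks.values())
overall_avg = total_turns / total_trajectories

print(total_trajectories)     # 1866
print(round(overall_avg, 2))  # 5.24
```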
## Training Details
TBD
## Example usage
```py
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

model_id = "mrm8488/mistral-7b-ft-AgentInstruct"

# Load tokenizer and model (fp16 roughly halves GPU memory for the weights compared with fp32)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")


class MyStoppingCriteria(StoppingCriteria):
    """Stop once `target_sequence` appears in the text generated after the prompt."""

    def __init__(self, target_sequence, prompt_length):
        self.target_sequence = target_sequence
        self.prompt_length = prompt_length  # number of prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newly generated tokens of the first sequence
        generated_text = tokenizer.decode(input_ids[0, self.prompt_length:], skip_special_tokens=True)
        done = self.target_sequence in generated_text
        # Return one boolean per sequence in the batch
        return torch.full((input_ids.shape[0],), done, dtype=torch.bool, device=input_ids.device)


def generate(context, max_new_tokens=256, min_new_tokens=64, temperature=0.3,
             top_p=0.75, top_k=40, do_sample=True, num_beams=2):
    # Tokenize the prompt and move input_ids / attention_mask to the GPU
    inputs = tokenizer(context, return_tensors="pt").to("cuda")
    prompt_length = inputs["input_ids"].shape[1]

    # Generation settings
    generation_settings = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "num_beams": num_beams,
        "early_stopping": False,
        "use_cache": True,
        # Stop when the model starts a new "### human:" turn
        "stopping_criteria": StoppingCriteriaList([MyStoppingCriteria("### human:", prompt_length)]),
    }

    # Generate response; generate() returns a tensor of token ids by default
    with torch.no_grad():
        generation_output = model.generate(**inputs, **generation_settings)
    return tokenizer.decode(generation_output[0], skip_special_tokens=True)


# Example usage (a Database-style task)
context = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""
solution = generate(context)
print(solution)
```
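Because the stopping criterion only fires after `### human:` has been generated, a reply may end with that marker. If generation stops at `max_new_tokens` instead, the marker is absent. Below is a minimal post-processing sketch. It assumes you pass in only the newly generated text, for instance by decoding the output ids after the prompt length. `trim_reply` is a hypothetical helper and is not part of this model card or of `transformers`.

```python
def trim_reply(generated: str, stop_sequence: str = "### human:") -> str:
    """Return the reply up to (not including) the first stop marker, trimmed of whitespace.

    `generated` is assumed to contain only newly generated text, without the prompt.
    """
    reply, _, _ = generated.partition(stop_sequence)
    return reply.strip()


# Made-up sample output, for illustration only
sample = "SELECT COUNT(*) FROM Customers;\n### human:"
print(trim_reply(sample))  # SELECT COUNT(*) FROM Customers;
```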
## Citation
```bibtex
@misc {manuel_romero_2024,
author = { {Manuel Romero} },
title = { mistral-7b-ft-AgentInstruct (Revision 463b96d) },
year = 2024,
url = { https://huggingface.co/mrm8488/mistral-7b-ft-AgentInstruct },
doi = { 10.57967/hf/1650 },
publisher = { Hugging Face }
}
```