---
library_name: transformers
license: apache-2.0
datasets:
- THUDM/AgentInstruct
language:
- en
---

# Mistral-7B fine-tuned on AgentInstruct

[Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) fine-tuned on the [AgentInstruct](https://huggingface.co/datasets/THUDM/AgentInstruct) dataset so that it acts *better* as an agent.

## Model Details

### Model Description

The Mistral-7B-v0.1 Large Language Model (LLM) is a pretrained generative text model with 7 billion parameters. According to Mistral AI, it outperforms Llama 2 13B on all benchmarks they tested.

For full details of the base model, please read Mistral AI's [paper](https://arxiv.org/abs/2310.06825) and [release blog post](https://mistral.ai/news/announcing-mistral-7b/).

## Model Architecture

Mistral-7B-v0.1 is a transformer model with the following architecture choices (all of them visible in the published model config, as sketched below):

- Grouped-Query Attention
- Sliding-Window Attention
- Byte-fallback BPE tokenizer
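
These choices can be confirmed programmatically. A minimal sketch, assuming only that the base model's config is available on the Hub (the attribute names come from the `transformers` `MistralConfig`):

```py
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

# Grouped-Query Attention: fewer key/value heads than query heads
print(config.num_attention_heads, config.num_key_value_heads)  # 32 8

# Sliding-Window Attention: each token attends to a fixed-size local window
print(config.sliding_window)  # 4096
```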

## Dataset Details

**AgentInstruct** is a meticulously curated dataset of **1,866** high-quality interactions, designed to enhance AI agents across six diverse real-world tasks and built with methods such as **Task Derivation** and **Self-Instruct**.

- **CoT** - Harnesses the power of [ReAct](https://react-lm.github.io/), providing detailed thought explanations for each action so the model's decision-making process stays transparent.
- **Diversity** - Spans 6 real-world scenarios, from daily household routines to database operations, with average turns ranging from 5 to 35.
- **Precision** - Not all GPT-4 trajectories are effective! Ours are rigorously filtered using strict rewards to ensure top-notch quality.
- **Assurance** - Rigorous checks to avoid data leakage, ensuring pristine dataset quality.
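
To inspect the data directly, you can load it from the Hub. A minimal sketch, assuming the `datasets` library; the split names roughly follow the task table in the next section and may differ between dataset revisions:

```py
from datasets import load_dataset

# One split per task (e.g. "alfworld", "webshop", "os", "db", "kg", "mind2web")
agent_instruct = load_dataset("THUDM/AgentInstruct")
print(agent_instruct)

# Each example is one multi-turn interaction trajectory
print(agent_instruct["alfworld"][0])
```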

## Task Overview

| Task | # Filtered Trajectories | Avg. Turns per Filtered Trajectory |
|---|---|---|
| ALFWorld | 336 | 13.52 |
| WebShop | 351 | 3.68 |
| Mind2Web | 122 | 1.00 |
| Knowledge Graph | 324 | 6.04 |
| Operating System | 195 | 3.85 |
| Database | 538 | 2.06 |
| **AgentInstruct** | 1,866 | 5.24 |

AgentInstruct comprises 1,866 interaction trajectories filtered from six agent tasks; the counts and averages above refer to the filtered trajectories only.
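
As a sanity check, the trajectory counts above can be recomputed from the loaded dataset. A sketch reusing `agent_instruct` from the previous snippet; the `conversations` field name is an assumption about the dataset schema, and the paper may count "turns" differently:

```py
# Count filtered trajectories and approximate turns per task split
for split, data in agent_instruct.items():
    # Treat one human/agent message pair as roughly one turn
    avg_turns = sum(len(ex["conversations"]) for ex in data) / (2 * len(data))
    print(f"{split}: {len(data)} trajectories, ~{avg_turns:.2f} turns on average")
```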

## Training Details

TBD

## Example of usage

```py
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")
model = AutoModelForCausalLM.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct").to("cuda")


class MyStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as `target_sequence` appears in the completion."""

    def __init__(self, target_sequence, prompt):
        self.target_sequence = target_sequence
        self.prompt = prompt

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the output, strip the prompt, and check for the stop sequence
        generated_text = tokenizer.decode(input_ids[0]).replace(self.prompt, "")
        return self.target_sequence in generated_text


def generate(context, max_new_tokens=256, min_new_tokens=64, temperature=0.3,
             top_p=0.75, top_k=40, do_sample=True, num_beams=2):
    # Prepare input data
    inputs = tokenizer(context, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cuda")
    attention_mask = inputs["attention_mask"].to("cuda")

    # Generation settings
    generation_settings = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "num_beams": num_beams,
        "early_stopping": False,
        "use_cache": True,
        # Stop once the model begins a new "### human:" turn
        "stopping_criteria": StoppingCriteriaList(
            [MyStoppingCriteria("### human:", context)]
        ),
    }

    # Generate response
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_settings,
        )

    # `generate` returns a tensor of token ids by default
    return tokenizer.decode(generation_output[0], skip_special_tokens=True)


# Example usage
context = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""

solution = generate(context)
print(solution)
```
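
For multi-turn interaction you can keep growing the same context string. A hypothetical follow-up, under the assumption (not documented here) that the fine-tuning data delimits user turns with `### human:`:

```py
# Hypothetical second turn: append the first answer plus a new question,
# then let the same stopping criterion cut the reply off at the next turn
context = solution + "\n### human: Now write the SQL query that performs this comparison."
print(generate(context))
```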

## Citation

```bibtex
@misc{manuel_romero_2024,
  author    = { {Manuel Romero} },
  title     = { mistral-7b-ft-AgentInstruct (Revision 463b96d) },
  year      = 2024,
  url       = { https://huggingface.co/mrm8488/mistral-7b-ft-AgentInstruct },
  doi       = { 10.57967/hf/1650 },
  publisher = { Hugging Face }
}
```