library_name: transformers
license: apache-2.0
datasets:
  - THUDM/AgentInstruct
language:
  - en

Mistral-7B fine-tuned on AgentInstruct

Mistral-7B-v0.1 fine-tuned on the AgentInstruct dataset to make it better at acting as an agent.

Model Details

Model Description

The base model, Mistral-7B-v0.1, is a pretrained generative large language model (LLM) with 7 billion parameters. Its authors at Mistral AI report that it outperforms Llama 2 13B on all benchmarks they tested.

For full details of the base model, see the Mistral 7B paper and release blog post.

Model Architecture

Mistral-7B-v0.1 is a transformer model, with the following architecture choices:

  • Grouped-Query Attention
  • Sliding-Window Attention
  • Byte-fallback BPE tokenizer
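The three choices above can be illustrated with a small, dependency-free sketch. It is not the model's implementation: the function names and the toy sizes (5 positions, window of 3, 8 query heads sharing 2 key/value heads) are hypothetical and chosen for readability, and the `<0xNN>` byte-token notation is assumed to follow the SentencePiece convention.

```python
def sliding_window_causal_mask(seq_len, window):
    """mask[i][j] is True when query position i may attend to key position j.

    Position i sees itself and at most `window - 1` preceding positions.
    """
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def kv_head_for_query_head(query_head, n_query_heads, n_kv_heads):
    """Map a query head to the key/value head it shares under grouped-query attention."""
    if n_query_heads % n_kv_heads != 0:
        raise ValueError("n_query_heads must be a multiple of n_kv_heads")
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size


def byte_fallback_tokens(char):
    """Represent an out-of-vocabulary character as one token per UTF-8 byte."""
    return [f"<0x{b:02X}>" for b in char.encode("utf-8")]


# Toy sizes for illustration only (hypothetical, not the model's configuration).
mask = sliding_window_causal_mask(seq_len=5, window=3)
print(mask[4])  # [False, False, True, True, True]

print([kv_head_for_query_head(q, 8, 2) for q in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]

print(byte_fallback_tokens("€"))  # ['<0xE2>', '<0x82>', '<0xAC>']
```

In a real model the mask is applied to the attention scores before the softmax, and the head mapping determines which key/value projections each group of query heads reuses.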

Dataset Details

AgentInstruct is a meticulously curated dataset featuring 1,866 high-quality interactions, designed to enhance AI agents across six diverse real-world tasks, leveraging innovative methods like Task Derivation and Self-Instruct.

  • 🔍 CoT - Harness the power of ReAct, offering detailed thought explanations for each action, ensuring an intricate understanding of the model's decision-making journey.
  • 🌍 Diversity - Spanning 6 real-world scenarios, from Daily Household Routines to Database Operations, and their average turns range from 5 to 35.
  • 🎯 Precision - Not all trajectories of GPT-4 are effective! Ours are rigorously filtered using strict rewards to ensure top-notch quality.
  • ✅ Assurance - Rigorous checks to avoid data leakage, ensuring pristine dataset quality.

Task Overview

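| Task             | # Filt. Traj. | Avg # Filt. Traj. Turns |
|------------------|---------------|-------------------------|
| ALFWorld         | 336           | 13.52                   |
| WebShop          | 351           | 3.68                    |
| Mind2Web         | 122           | 1.00                    |
| Knowledge Graph  | 324           | 6.04                    |
| Operating System | 195           | 3.85                    |
| Database         | 538           | 2.06                    |
| AgentInstruct    | 1866          | 5.24                    |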

AgentInstruct includes 1,866 trajectories from 6 agent tasks. "Traj." stands for interaction trajectory, and "Filt. Traj." stands for filtered trajectories.
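As a quick sanity check, the final AgentInstruct row can be recomputed from the per-task rows. The sketch below assumes the overall average is the trajectory-weighted mean of the per-task averages; the variable names are illustrative.

```python
# (task, filtered trajectories, average turns), copied from the table above
tasks = [
    ("ALFWorld", 336, 13.52),
    ("WebShop", 351, 3.68),
    ("Mind2Web", 122, 1.00),
    ("Knowledge Graph", 324, 6.04),
    ("Operating System", 195, 3.85),
    ("Database", 538, 2.06),
]

# Total number of filtered trajectories across all six tasks
total_trajectories = sum(count for _, count, _ in tasks)

# Average turns weighted by the number of trajectories per task
weighted_avg_turns = sum(count * turns for _, count, turns in tasks) / total_trajectories

print(total_trajectories)            # 1866
print(round(weighted_avg_turns, 2))  # 5.24
```

Both values match the table's last row, which suggests the overall figure is weighted by trajectory count and is not a plain mean of the six per-task averages.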

Training Details

TBD

Example of usage

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")
model = AutoModelForCausalLM.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct").to("cuda")

class MyStoppingCriteria(StoppingCriteria):
    """Stop generation once the target sequence appears in the newly generated text."""

    def __init__(self, target_sequence, prompt_length):
        self.target_sequence = target_sequence
        self.prompt_length = prompt_length  # number of prompt tokens to skip

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the tokens generated after the prompt and check for the target sequence
        generated_text = tokenizer.decode(input_ids[0][self.prompt_length:])
        return self.target_sequence in generated_text

def generate(context, max_new_tokens=256, min_new_tokens=64, temperature=0.3, top_p=0.75, top_k=40, do_sample=True, num_beams=2):
    # Prepare input data
    inputs = tokenizer(context, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cuda")
    attention_mask = inputs["attention_mask"].to("cuda")

    # Generation settings
    generation_settings = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "num_beams": num_beams,
        "early_stopping": False,
        "use_cache": True,
        # Return an output object so that `.sequences` is available below
        "return_dict_in_generate": True,
        # Stop as soon as the model starts writing the next human turn
        "stopping_criteria": StoppingCriteriaList(
            [MyStoppingCriteria("### human:", input_ids.shape[1])]
        ),
    }

    # Generate response (tensors are passed by keyword; the second positional
    # argument of `generate` is not the attention mask)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_settings,
        )

    # The decoded text contains the prompt followed by the model's continuation
    output = tokenizer.decode(generation_output.sequences[0])
    return output

# Example usage
context = ""
human = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""
context = human

solution = generate(context)
print(solution)
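The decoded output contains the prompt itself and may end with the "### human:" marker that triggers the stopping criterion. A minimal post-processing sketch for isolating the model's reply is shown below. The helper name `extract_response` is hypothetical, and stripping `<s>` and `</s>` assumes the tokenizer uses those strings as its begin- and end-of-sequence tokens.

```python
def extract_response(decoded, prompt, stop_sequence="### human:"):
    """Return only the model's reply from a decoded generation (illustrative helper)."""
    # Drop everything up to and including the prompt, if it is present
    prompt_start = decoded.find(prompt)
    text = decoded[prompt_start + len(prompt):] if prompt_start != -1 else decoded

    # Cut the reply where the model begins a new human turn
    stop_start = text.find(stop_sequence)
    if stop_start != -1:
        text = text[:stop_start]

    # Remove special tokens (assumed to be "<s>" and "</s>")
    for special_token in ("<s>", "</s>"):
        text = text.replace(special_token, "")

    return text.strip()


# Toy strings for illustration, not real model output
example_prompt = "### human: List the tables."
example_decoded = "<s> ### human: List the tables. SELECT 1; ### human:"
print(extract_response(example_decoded, example_prompt))  # SELECT 1;
```

With the example above, this would be called as `extract_response(solution, context)`.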

Citation

@misc {manuel_romero_2024,
    author       = { {Manuel Romero} },
    title        = { mistral-7b-ft-AgentInstruct (Revision 463b96d) },
    year         = 2024,
    url          = { https://huggingface.co/mrm8488/mistral-7b-ft-AgentInstruct },
    doi          = { 10.57967/hf/1650 },
    publisher    = { Hugging Face }
}