# Hugging Face Space status: Paused
from datetime import datetime
from typing import Dict, List

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class MissionContext:
    """Rolling conversation state for the mission generator.

    Only the most recent ``max_history`` messages are retained; older ones
    are dropped as new messages arrive.
    """

    def __init__(self, max_history: int = 5):
        """Create an empty context.

        Args:
            max_history: Maximum number of messages to keep (default 5,
                matching the previous hard-coded limit). 0 keeps none.

        Raises:
            ValueError: If ``max_history`` is negative.
        """
        if max_history < 0:
            raise ValueError("max_history must be >= 0")
        self.max_history = max_history
        # Not read by the code in this module; kept for external callers.
        self.mission_counter = 1
        self.current_objectives = {}
        # Entries are {"role", "content", "timestamp"} dicts, oldest first.
        self.conversation_history = []

    def add_to_history(self, role: str, content: str) -> None:
        """Append one message, then drop the oldest beyond ``max_history``.

        Args:
            role: Speaker role (this module uses "user" and "assistant").
            content: Message text.
        """
        self.conversation_history.append({
            "role": role,
            "content": content,
            # Naive local time in ISO-8601 form.
            "timestamp": datetime.now().isoformat(),
        })
        # Trim every surplus entry from the front (not just one), so the
        # limit also holds with max_history=0 or after external appends.
        overflow = len(self.conversation_history) - self.max_history
        if overflow > 0:
            del self.conversation_history[:overflow]
class MissionGenerator:
    """Generate Original War mission objectives from natural-language requests.

    Wraps a seq2seq language model (FLAN-T5-base by default) and a
    MissionContext whose recent messages are fed back into each prompt.
    """

    # Marker that opens the objectives section, both in model output and in
    # the fallback template.
    OBJECTIVES_MARKER = "# M1"

    def __init__(self, model_name: str = "google/flan-t5-base"):
        """Load the tokenizer and model and start an empty conversation.

        Args:
            model_name: Hugging Face model id to load. The default,
                FLAN-T5-base, is a free, lightweight model that is good at
                instruction following.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # The prompt ends with the task and the current request. When it is
        # longer than the 512-token input budget, cut tokens from the start
        # (oldest history) instead of the end, so the request is kept.
        self.tokenizer.truncation_side = "left"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.context = MissionContext()

    def format_conversation_history(self) -> str:
        """Render stored history as "User: ..." / "Assistant: ..." lines.

        Every line ends with a newline; an empty history yields "".
        """
        lines = []
        for msg in self.context.conversation_history:
            role = "User" if msg["role"] == "user" else "Assistant"
            lines.append(f"{role}: {msg['content']}\n")
        return "".join(lines)

    @staticmethod
    def _build_prompt(history: str, user_input: str) -> str:
        """Assemble the model prompt from prior turns and the new request."""
        return (
            "\n"
            "Previous conversation:\n"
            f"{history}\n"
            "Task: Generate a mission for Original War game based on the conversation.\n"
            "Format the response as follows:\n"
            "1. A conversational response understanding the mission\n"
            "2. The mission objectives in Original War format using "
            "Add Main/Secondary/Alternative\n"
            f"Current request: {user_input}\n"
        )

    @classmethod
    def _split_response(cls, full_response: str) -> tuple[str, str]:
        """Split raw model output into (chat text, objectives).

        Text before the first objectives marker (stripped) is the chat text.
        The marker and *everything* after it are the objectives. If there is
        no marker, the objectives part is "".
        """
        chat, marker, rest = full_response.partition(cls.OBJECTIVES_MARKER)
        return chat.strip(), (marker + rest if marker else "")

    def generate_response(self, user_input: str) -> tuple[str, str]:
        """Generate a conversational reply and formatted mission objectives.

        Records both the request and the reply in ``self.context``.

        Args:
            user_input: The user's mission description or follow-up.

        Returns:
            ``(chat_response, formatted_objectives)``. The objectives fall
            back to a fixed template when the model output has no "# M1"
            section.
        """
        # Render history *before* recording the new message. The request is
        # already passed once as "Current request"; listing it under
        # "Previous conversation" too would duplicate it.
        history = self.format_conversation_history()
        self.context.add_to_history("user", user_input)
        prompt = self._build_prompt(history, user_input)

        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        # Deterministic beam search. A temperature would be ignored here
        # because sampling is not enabled. **inputs also forwards the
        # attention mask.
        outputs = self.model.generate(
            **inputs,
            max_length=256,
            num_beams=4,
            no_repeat_ngram_size=2,
        )
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        chat_response, formatted_objectives = self._split_response(full_response)
        if not formatted_objectives:
            formatted_objectives = self.generate_fallback_objectives(user_input)

        self.context.add_to_history("assistant", chat_response)
        return chat_response, formatted_objectives

    def generate_fallback_objectives(self, user_input: str) -> str:
        """Return a generic objectives template for when generation gives none.

        Args:
            user_input: Currently unused; accepted for interface stability.
        """
        return (
            "# M1\n"
            "Add Main mission_objective\n"
            "- Complete the primary mission goal\n"
            "Add Secondary bonus_objective\n"
            "- Optional additional task\n"
            "#"
        )
def create_gradio_interface():
    """Build the Gradio UI for the mission objective generator.

    One MissionGenerator is created here and shared by every event handler
    below. All users of the returned app therefore share one conversation
    history.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    generator = MissionGenerator()

    def process_input(user_input: str, history: List[List[str]]) -> tuple[List[List[str]], str]:
        """Send one message to the generator and update both output panes."""
        # Work on a copy; the incoming chat value may also be None.
        history = list(history or [])
        chat_response, objectives = generator.generate_response(user_input)
        # gr.Chatbot's default value format is a list of [user, bot] pairs.
        # The {"user": ..., "bot": ...} dicts used before are not a format
        # the component understands.
        history.append([user_input, chat_response])
        return history, objectives

    def clear_conversation() -> tuple[list, str]:
        """Empty the chat, the objectives box and the model-side history."""
        # Without this, the next prompt would still include the "cleared" turns.
        generator.context.conversation_history.clear()
        return [], ""

    with gr.Blocks() as interface:
        gr.Markdown(
            "\n"
            "# Original War Mission Objective Generator\n"
            "Describe your mission scenario in natural language, and I'll help "
            "you create formatted mission objectives.\n"
        )
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(
            label="Describe your mission scenario",
            placeholder="Tell me about the mission you want to create...",
        )
        clear = gr.Button("Clear Conversation")
        formatted_output = gr.Textbox(
            label="Generated Mission Objectives",
            lines=10,
            placeholder="Mission objectives will appear here...",
        )

        msg.submit(
            process_input,
            inputs=[msg, chatbot],
            outputs=[chatbot, formatted_output],
        )
        clear.click(clear_conversation, outputs=[chatbot, formatted_output])

        # gr.Examples requires the component(s) its rows populate. Clicking
        # an example copies its text into the scenario textbox.
        gr.Examples(
            examples=[
                ["I need a mission where players have to infiltrate an enemy base. They should try to avoid detection, but if they get spotted, they'll need to fight their way through."],
                ["Create a defensive mission where players protect a convoy. They should also try to minimize civilian casualties."],
                ["I want players to capture a strategic point. They can either do it by force or try diplomatic negotiations with the local faction."],
            ],
            inputs=[msg],
        )
    return interface
# Launch the interface when this file is run directly as a script.
if __name__ == "__main__":
    # Build the Gradio Blocks app, then start serving it.
    iface = create_gradio_interface()
    iface.launch()

# Disabled alternative front end. The text below is a bare module-level
# string expression, so the code inside it never runs. It sketches a Discord
# bot that reuses MissionGenerator. Enabling it would require the discord.py
# package and a DISCORD_TOKEN environment variable.
# NOTE(review): verify against current discord.py before enabling:
#   - discord.py 2.x requires an `intents` argument to commands.Bot, and
#     reading "!mission ..." message text needs the message-content intent.
#   - A @commands.command method defined directly on a Bot subclass is not
#     registered automatically. Commands normally live in a Cog or are
#     declared with @bot.command().
#   - Only formatted_output's length is checked against Discord's 2000-char
#     limit; chat_response and the code fences are not counted.
#   - The "π" message prefix looks like a mis-encoded emoji; confirm intent.
"""
# Discord bot implementation using the same generator
import discord
from discord.ext import commands
import os
class MissionBot(commands.Bot):
def __init__(self):
super().__init__(command_prefix="!")
self.generator = MissionGenerator()
async def on_ready(self):
print(f'{self.user} has connected to Discord!')
@commands.command(name='mission')
async def generate_mission(self, ctx, *, description: str):
chat_response, formatted_output = self.generator.generate_response(description)
# Split response if it's too long for Discord
if len(formatted_output) > 1990: # Discord has 2000 char limit
await ctx.send(f"π {chat_response}")
await ctx.send(f"```\n{formatted_output}\n```")
else:
await ctx.send(f"π {chat_response}\n\n```\n{formatted_output}\n```")
# Initialize and run the bot
if __name__ == "__main__":
bot = MissionBot()
bot.run(os.getenv('DISCORD_TOKEN'))
"""