|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from openhands.controller.agent import Agent
|
|
from openhands.controller.agent_controller import AgentController
|
|
from openhands.controller.state.state import State, TrafficControlState
|
|
from openhands.core.config import AppConfig
|
|
from openhands.core.main import run_controller
|
|
from openhands.core.schema import AgentState
|
|
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
|
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
|
from openhands.events.observation import (
|
|
ErrorObservation,
|
|
)
|
|
from openhands.events.serialization import event_to_dict
|
|
from openhands.llm import LLM
|
|
from openhands.llm.metrics import Metrics
|
|
from openhands.runtime.base import Runtime
|
|
from openhands.storage.memory import InMemoryFileStore
|
|
|
|
|
|
@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Create a fresh temporary directory and hand back its path as a string."""
    directory = tmp_path_factory.mktemp('test_event_stream')
    return str(directory)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def event_loop():
    """Provide a brand-new asyncio event loop for each test, closing it on teardown."""
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Build an Agent mock (spec'd) whose LLM mock carries a real Metrics instance."""
    llm_mock = MagicMock(spec=LLM)
    llm_mock.metrics = Metrics()
    agent = MagicMock(spec=Agent)
    agent.llm = llm_mock
    return agent
|
|
|
|
|
|
@pytest.fixture
def mock_event_stream():
    """Build an EventStream mock (spec'd) whose get_latest_event_id() returns 0."""
    stream = MagicMock(spec=EventStream)
    stream.get_latest_event_id.return_value = 0
    return stream
|
|
|
|
|
|
@pytest.fixture
def mock_status_callback():
    """Supply an AsyncMock to stand in for the controller's status callback."""
    callback = AsyncMock()
    return callback
|
|
|
|
|
|
async def send_event_to_controller(controller, event):
    """Deliver *event* to ``controller._on_event`` and then sleep for 0.1 s.

    The sleep hands control back to the event loop, presumably so that any
    tasks the handler scheduled can run before the caller inspects state.
    """
    handler = controller._on_event
    await handler(event)
    await asyncio.sleep(0.1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
    """Each state passed to set_agent_state_to() is reported by get_agent_state()."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    # Walk RUNNING -> PAUSED, checking the reported state after each change.
    for target in (AgentState.RUNNING, AgentState.PAUSED):
        await controller.set_agent_state_to(target)
        assert controller.get_agent_state() == target
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
    """A MessageAction received while RUNNING leaves the agent in RUNNING."""
    controller_options = dict(
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller = AgentController(
        agent=mock_agent, event_stream=mock_event_stream, **controller_options
    )
    controller.state.agent_state = AgentState.RUNNING

    await send_event_to_controller(controller, MessageAction(content='Test message'))

    assert controller.get_agent_state() == AgentState.RUNNING
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
    """A ChangeAgentStateAction targeting PAUSED moves a RUNNING agent to PAUSED."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller.state.agent_state = AgentState.RUNNING

    target_state = AgentState.PAUSED
    await send_event_to_controller(
        controller, ChangeAgentStateAction(agent_state=target_state)
    )

    assert controller.get_agent_state() == target_state
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
    """_react_to_exception() invokes the controller's status callback exactly once."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        status_callback=mock_status_callback,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    await controller._react_to_exception(RuntimeError('Test error'))

    controller.status_callback.assert_called_once()
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error():
    """Run the full controller loop against an agent that always emits ``ls``.

    The mocked runtime answers every CmdRunAction with an ErrorObservation
    (sourced as USER). The run is expected to stop in ERROR after 4
    iterations, record the stuck-in-loop error, and leave 11 events in the
    stream.
    """
    config = AppConfig()
    file_store = InMemoryFileStore({})
    event_stream = EventStream(sid='test', file_store=file_store)

    agent = MagicMock(spec=Agent)

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        return CmdRunAction(command='ls')

    agent.step = agent_step_fn
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    runtime = MagicMock(spec=Runtime)

    def on_event(event: Event):
        # Stand-in for the runtime: reply to each command with an error,
        # linked back to the triggering action via _cause.
        if isinstance(event, CmdRunAction):
            error_obs = ErrorObservation('You messed around with Jim')
            error_obs._cause = event.id
            event_stream.add_event(error_obs, EventSource.USER)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )
    print(f'state: {state}')
    events = list(event_stream.get_events())
    print(f'event_stream: {events}')
    assert state.iteration == 4
    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
    assert len(events) == 11
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
    """An agent that repeats ``ls`` and keeps receiving the same error is halted.

    The mocked runtime replies to every CmdRunAction with an identical
    ENVIRONMENT ErrorObservation. Expected outcome:
    - 4 iterations and 11 events in the stream;
    - events 2..9 alternate run/``ls`` actions with their error observations;
    - the final event is an agent_state_changed to ``error``;
    - state.last_error reports AgentStuckInLoopError.
    """
    config = AppConfig()
    event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))

    agent = MagicMock(spec=Agent)

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        return CmdRunAction(command='ls')

    agent.step = agent_step_fn
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()
    runtime = MagicMock(spec=Runtime)

    def on_event(event: Event):
        if not isinstance(event, CmdRunAction):
            return
        repeated_error = ErrorObservation('Non fatal error here to trigger loop')
        repeated_error._cause = event.id
        event_stream.add_event(repeated_error, EventSource.ENVIRONMENT)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )
    events = list(event_stream.get_events())
    print(f'state: {state}')
    for idx, evt in enumerate(events):
        print(f'event {idx}: {event_to_dict(evt)}')

    assert state.iteration == 4
    assert len(events) == 11

    # Events 2..9 come in (action, observation) pairs.
    for pos in range(2, 10, 2):
        action_dict = event_to_dict(events[pos])
        observation_dict = event_to_dict(events[pos + 1])
        assert action_dict['action'] == 'run'
        assert action_dict['args']['command'] == 'ls'
        assert observation_dict['observation'] == 'error'
        assert observation_dict['content'] == 'Non fatal error here to trigger loop'

    final_event = event_to_dict(events[-1])
    assert final_event['extras']['agent_state'] == 'error'
    assert final_event['observation'] == 'agent_state_changed'

    assert state.agent_state == AgentState.ERROR
    assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
    """Check how a user message at the iteration cap affects max_iterations.

    Non-headless: once _step() throttles at iteration 10, a user message
    raises max_iterations to 20 and returns the agent to RUNNING.
    Headless: the same message leaves max_iterations at 10, and the next
    _step() throttles the agent into ERROR.
    """

    def build_controller(headless: bool) -> AgentController:
        # Each scenario gets its own State, capped at 10 and already at iteration 10.
        controller = AgentController(
            agent=mock_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='test',
            confirmation_mode=False,
            headless_mode=headless,
            initial_state=State(max_iterations=10),
        )
        controller.state.agent_state = AgentState.RUNNING
        controller.state.iteration = 10
        return controller

    def user_message() -> MessageAction:
        msg = MessageAction(content='Test message')
        msg._source = EventSource.USER
        return msg

    # --- Non-headless: the user message extends the budget. ---
    interactive = build_controller(headless=False)
    assert interactive.state.traffic_control_state == TrafficControlState.NORMAL

    await interactive._step()
    assert interactive.state.traffic_control_state == TrafficControlState.THROTTLING
    assert interactive.state.agent_state == AgentState.ERROR

    await send_event_to_controller(interactive, user_message())
    assert interactive.state.max_iterations == 20
    assert interactive.state.traffic_control_state == TrafficControlState.NORMAL
    assert interactive.state.agent_state == AgentState.RUNNING
    await interactive.close()

    # --- Headless: the user message does not extend the budget. ---
    unattended = build_controller(headless=True)
    assert unattended.state.traffic_control_state == TrafficControlState.NORMAL

    await send_event_to_controller(unattended, user_message())
    assert unattended.state.max_iterations == 10

    await unattended._step()
    assert unattended.state.traffic_control_state == TrafficControlState.THROTTLING
    assert unattended.state.agent_state == AgentState.ERROR
    await unattended.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
    """Non-headless: once cost exceeds max_budget_per_task, _step() throttles into ERROR."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=False,
    )
    state = controller.state
    state.agent_state = AgentState.RUNNING
    # Slightly above the configured budget of 10.
    state.metrics.accumulated_cost = 10.1
    assert state.traffic_control_state == TrafficControlState.NORMAL

    await controller._step()

    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    assert controller.state.agent_state == AgentState.ERROR
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
    """Headless: once cost exceeds max_budget_per_task, _step() throttles into ERROR."""
    budget_limit = 10
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=budget_limit,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller.state.agent_state = AgentState.RUNNING
    controller.state.metrics.accumulated_cost = 10.1
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    await controller._step()

    after = controller.state
    assert after.traffic_control_state == TrafficControlState.THROTTLING
    assert after.agent_state == AgentState.ERROR
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
    """_reset() with an unanswered pending action that carries tool-call metadata.

    Expects a single AGENT-sourced ErrorObservation ('The action has not been
    executed.') that mirrors the action's metadata and is caused by it. The
    pending action must be cleared and agent.reset() called once.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    tool_metadata = {'function': 'test_function', 'args': {'arg1': 'value1'}}
    pending_action = CmdRunAction(command='test')
    pending_action.tool_call_metadata = tool_metadata
    controller._pending_action = pending_action

    controller._reset()

    mock_event_stream.add_event.assert_called_once()
    (error_obs, source), _ = mock_event_stream.add_event.call_args
    assert isinstance(error_obs, ErrorObservation)
    assert error_obs.content == 'The action has not been executed.'
    assert error_obs.tool_call_metadata == pending_action.tool_call_metadata
    assert error_obs._cause == pending_action.id
    assert source == EventSource.AGENT

    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation(
    mock_agent, mock_event_stream
):
    """_reset() when history already holds an observation for the pending action.

    The observation shares the pending action's tool-call metadata, so no new
    event may be emitted. The pending action must still be cleared and
    agent.reset() called once.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    pending_action = CmdRunAction(command='test')
    pending_action.tool_call_metadata = {
        'function': 'test_function',
        'args': {'arg1': 'value1'},
    }
    controller._pending_action = pending_action

    prior_observation = ErrorObservation(content='Previous error')
    prior_observation.tool_call_metadata = pending_action.tool_call_metadata
    controller.state.history.append(prior_observation)

    controller._reset()

    mock_event_stream.add_event.assert_not_called()
    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
    """_reset() with nothing pending emits no event and still resets the agent."""
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    controller._reset()

    mock_event_stream.add_event.assert_not_called()
    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata(
    mock_agent, mock_event_stream, monkeypatch
):
    """_reset() when the pending action appears to lack tool-call metadata.

    builtins.hasattr is monkeypatched so that only this specific action
    reports no ``tool_call_metadata``. Expects no emitted event, a cleared
    pending action, and exactly one agent.reset() call.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    pending_action = CmdRunAction(command='test')

    original_hasattr = hasattr

    def mock_hasattr(obj, name):
        # Identity check, not ``==``: the patch is global while active, so
        # comparing every probed object by equality would invoke arbitrary
        # __eq__ implementations and could match a distinct, field-equal event.
        if obj is pending_action and name == 'tool_call_metadata':
            return False
        return original_hasattr(obj, name)

    monkeypatch.setattr('builtins.hasattr', mock_hasattr)
    controller._pending_action = pending_action

    controller._reset()

    mock_event_stream.add_event.assert_not_called()
    assert controller._pending_action is None
    mock_agent.reset.assert_called_once()
    await controller.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics():
    """Hitting max_iterations must still carry the accumulated LLM cost into the state.

    With max_iterations=3, each agent step adds 10.0 to the LLM metrics. The
    run should end in ERROR with the headless max-iteration message and an
    accumulated cost of 30.0 on state.metrics.
    """
    config = AppConfig(max_iterations=3)
    event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))

    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    step_cost = 10.0

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        metrics = agent.llm.metrics
        metrics.add_cost(step_cost)
        print(f'agent.llm.metrics.accumulated_cost: {metrics.accumulated_cost}')
        return CmdRunAction(command='ls')

    agent.step = agent_step_fn

    runtime = MagicMock(spec=Runtime)

    def on_event(event: Event):
        if not isinstance(event, CmdRunAction):
            return
        reply = ErrorObservation('Non fatal error. event id: ' + str(event.id))
        reply._cause = event.id
        event_stream.add_event(reply, EventSource.ENVIRONMENT)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )

    expected_error = (
        'RuntimeError: Agent reached maximum iteration in headless mode. '
        'Current iteration: 3, max iteration: 3'
    )
    assert state.iteration == 3
    assert state.agent_state == AgentState.ERROR
    assert state.last_error == expected_error
    assert (
        state.metrics.accumulated_cost == step_cost * 3
    ), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'
|
|
|