|
import pathlib
|
|
|
|
import pytest
|
|
|
|
from openhands.core.config import AppConfig
|
|
from openhands.core.config.utils import load_from_toml
|
|
|
|
|
|
@pytest.fixture
|
|
def default_config(monkeypatch):
|
|
|
|
yield AppConfig()
|
|
|
|
|
|
@pytest.fixture
|
|
def generic_llm_toml(tmp_path: pathlib.Path) -> str:
|
|
"""Fixture to create a generic LLM TOML configuration with all custom LLMs
|
|
providing mandatory 'model' and 'api_key', and testing fallback to the generic section values
|
|
for other attributes like 'num_retries'.
|
|
"""
|
|
toml_content = """
|
|
[core]
|
|
workspace_base = "./workspace"
|
|
|
|
[llm]
|
|
model = "base-model"
|
|
api_key = "base-api-key"
|
|
embedding_model = "base-embedding"
|
|
num_retries = 3
|
|
|
|
[llm.custom1]
|
|
model = "custom-model-1"
|
|
api_key = "custom-api-key-1"
|
|
# 'num_retries' is not overridden and should fallback to the value from [llm]
|
|
|
|
[llm.custom2]
|
|
model = "custom-model-2"
|
|
api_key = "custom-api-key-2"
|
|
num_retries = 5 # Overridden value
|
|
|
|
[llm.custom3]
|
|
model = "custom-model-3"
|
|
api_key = "custom-api-key-3"
|
|
# No overrides for additional attributes
|
|
"""
|
|
toml_file = tmp_path / 'llm_config.toml'
|
|
toml_file.write_text(toml_content)
|
|
return str(toml_file)
|
|
|
|
|
|
def test_load_from_toml_llm_with_fallback(
|
|
default_config: AppConfig, generic_llm_toml: str
|
|
) -> None:
|
|
"""Test that custom LLM configurations fallback non-overridden attributes
|
|
like 'num_retries' from the generic [llm] section.
|
|
"""
|
|
load_from_toml(default_config, generic_llm_toml)
|
|
|
|
|
|
generic_llm = default_config.get_llm_config('llm')
|
|
assert generic_llm.model == 'base-model'
|
|
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
|
|
assert generic_llm.embedding_model == 'base-embedding'
|
|
assert generic_llm.num_retries == 3
|
|
|
|
|
|
custom1 = default_config.get_llm_config('custom1')
|
|
assert custom1.model == 'custom-model-1'
|
|
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
|
|
assert custom1.embedding_model == 'base-embedding'
|
|
assert custom1.num_retries == 3
|
|
|
|
|
|
custom2 = default_config.get_llm_config('custom2')
|
|
assert custom2.model == 'custom-model-2'
|
|
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
|
|
assert custom2.embedding_model == 'base-embedding'
|
|
assert custom2.num_retries == 5
|
|
|
|
|
|
custom3 = default_config.get_llm_config('custom3')
|
|
assert custom3.model == 'custom-model-3'
|
|
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
|
|
assert custom3.embedding_model == 'base-embedding'
|
|
assert custom3.num_retries == 3
|
|
|
|
|
|
def test_load_from_toml_llm_custom_overrides_all(
|
|
default_config: AppConfig, tmp_path: pathlib.Path
|
|
) -> None:
|
|
"""Test that a custom LLM can fully override all attributes from the generic [llm] section."""
|
|
toml_content = """
|
|
[core]
|
|
workspace_base = "./workspace"
|
|
|
|
[llm]
|
|
model = "base-model"
|
|
api_key = "base-api-key"
|
|
embedding_model = "base-embedding"
|
|
num_retries = 3
|
|
|
|
[llm.custom_full]
|
|
model = "full-custom-model"
|
|
api_key = "full-custom-api-key"
|
|
embedding_model = "full-custom-embedding"
|
|
num_retries = 10
|
|
"""
|
|
toml_file = tmp_path / 'full_override_llm.toml'
|
|
toml_file.write_text(toml_content)
|
|
|
|
load_from_toml(default_config, str(toml_file))
|
|
|
|
|
|
generic_llm = default_config.get_llm_config('llm')
|
|
assert generic_llm.model == 'base-model'
|
|
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
|
|
assert generic_llm.embedding_model == 'base-embedding'
|
|
assert generic_llm.num_retries == 3
|
|
|
|
|
|
custom_full = default_config.get_llm_config('custom_full')
|
|
assert custom_full.model == 'full-custom-model'
|
|
assert custom_full.api_key.get_secret_value() == 'full-custom-api-key'
|
|
assert custom_full.embedding_model == 'full-custom-embedding'
|
|
assert custom_full.num_retries == 10
|
|
|
|
|
|
def test_load_from_toml_llm_custom_partial_override(
|
|
default_config: AppConfig, generic_llm_toml: str
|
|
) -> None:
|
|
"""Test that custom LLM configurations can partially override attributes
|
|
from the generic [llm] section while inheriting others.
|
|
"""
|
|
load_from_toml(default_config, generic_llm_toml)
|
|
|
|
|
|
custom1 = default_config.get_llm_config('custom1')
|
|
assert custom1.model == 'custom-model-1'
|
|
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
|
|
assert custom1.embedding_model == 'base-embedding'
|
|
assert custom1.num_retries == 3
|
|
|
|
|
|
custom2 = default_config.get_llm_config('custom2')
|
|
assert custom2.model == 'custom-model-2'
|
|
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
|
|
assert custom2.embedding_model == 'base-embedding'
|
|
assert custom2.num_retries == 5
|
|
|
|
|
|
def test_load_from_toml_llm_custom_no_override(
|
|
default_config: AppConfig, generic_llm_toml: str
|
|
) -> None:
|
|
"""Test that custom LLM configurations with no additional overrides
|
|
inherit all non-specified attributes from the generic [llm] section.
|
|
"""
|
|
load_from_toml(default_config, generic_llm_toml)
|
|
|
|
|
|
custom3 = default_config.get_llm_config('custom3')
|
|
assert custom3.model == 'custom-model-3'
|
|
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
|
|
assert custom3.embedding_model == 'base-embedding'
|
|
assert custom3.num_retries == 3
|
|
|
|
|
|
def test_load_from_toml_llm_missing_generic(
|
|
default_config: AppConfig, tmp_path: pathlib.Path
|
|
) -> None:
|
|
"""Test that custom LLM configurations without a generic [llm] section
|
|
use only their own attributes and fallback to defaults for others.
|
|
"""
|
|
toml_content = """
|
|
[core]
|
|
workspace_base = "./workspace"
|
|
|
|
[llm.custom_only]
|
|
model = "custom-only-model"
|
|
api_key = "custom-only-api-key"
|
|
"""
|
|
toml_file = tmp_path / 'custom_only_llm.toml'
|
|
toml_file.write_text(toml_content)
|
|
|
|
load_from_toml(default_config, str(toml_file))
|
|
|
|
|
|
custom_only = default_config.get_llm_config('custom_only')
|
|
assert custom_only.model == 'custom-only-model'
|
|
assert custom_only.api_key.get_secret_value() == 'custom-only-api-key'
|
|
assert custom_only.embedding_model == 'local'
|
|
assert custom_only.num_retries == 8
|
|
|
|
|
|
def test_load_from_toml_llm_invalid_config(
|
|
default_config: AppConfig, tmp_path: pathlib.Path
|
|
) -> None:
|
|
"""Test that invalid custom LLM configurations do not override the generic
|
|
and raise appropriate warnings.
|
|
"""
|
|
toml_content = """
|
|
[core]
|
|
workspace_base = "./workspace"
|
|
|
|
[llm]
|
|
model = "base-model"
|
|
api_key = "base-api-key"
|
|
num_retries = 3
|
|
|
|
[llm.invalid_custom]
|
|
unknown_attr = "should_not_exist"
|
|
"""
|
|
toml_file = tmp_path / 'invalid_custom_llm.toml'
|
|
toml_file.write_text(toml_content)
|
|
|
|
load_from_toml(default_config, str(toml_file))
|
|
|
|
|
|
generic_llm = default_config.get_llm_config('llm')
|
|
assert generic_llm.model == 'base-model'
|
|
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
|
|
assert generic_llm.num_retries == 3
|
|
|
|
|
|
custom_invalid = default_config.get_llm_config('invalid_custom')
|
|
assert custom_invalid.model == 'base-model'
|
|
assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
|
|
assert custom_invalid.num_retries == 3
|
|
assert custom_invalid.embedding_model == 'local'
|
|
|