# zzz / tests/unit/test_llm_config.py
# ar08's picture
# Upload 1040 files
# 246d201 verified
import pathlib
import pytest
from openhands.core.config import AppConfig
from openhands.core.config.utils import load_from_toml
@pytest.fixture
def default_config() -> AppConfig:
    """Provide a default-constructed ``AppConfig`` for a test.

    The tests pass this instance to ``load_from_toml`` and then read the
    loaded LLM sections back from the same object, so each test gets its own
    fresh instance (function scope, the pytest default). No teardown is
    needed, hence a plain ``return`` rather than ``yield``; the previously
    requested but unused ``monkeypatch`` fixture has been dropped.
    """
    return AppConfig()
@pytest.fixture
def generic_llm_toml(tmp_path: pathlib.Path) -> str:
    """Write a TOML file with a generic [llm] table and three custom LLM tables.

    Each custom table supplies its own mandatory 'model' and 'api_key'. Only
    ``custom2`` overrides 'num_retries' (to 5); the other custom tables are
    expected to fall back to the generic values for attributes such as
    'num_retries' and 'embedding_model'.

    Returns the path of the written file as a string.
    """
    config_path = tmp_path / 'llm_config.toml'
    config_path.write_text(
        """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3
[llm.custom1]
model = "custom-model-1"
api_key = "custom-api-key-1"
# 'num_retries' is not overridden and should fallback to the value from [llm]
[llm.custom2]
model = "custom-model-2"
api_key = "custom-api-key-2"
num_retries = 5 # Overridden value
[llm.custom3]
model = "custom-model-3"
api_key = "custom-api-key-3"
# No overrides for additional attributes
"""
    )
    return str(config_path)
def test_load_from_toml_llm_with_fallback(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Check that custom [llm.*] tables inherit unset attributes from [llm].

    The generic table is checked first, then every custom table. All of them
    share the generic 'embedding_model'; only ``custom2`` replaces
    'num_retries'.
    """
    load_from_toml(default_config, generic_llm_toml)

    # section name -> (model, api_key, embedding_model, num_retries)
    expected = {
        'llm': ('base-model', 'base-api-key', 'base-embedding', 3),
        # num_retries inherited from [llm]
        'custom1': ('custom-model-1', 'custom-api-key-1', 'base-embedding', 3),
        # num_retries overridden in its own table
        'custom2': ('custom-model-2', 'custom-api-key-2', 'base-embedding', 5),
        # everything except model/api_key inherited from [llm]
        'custom3': ('custom-model-3', 'custom-api-key-3', 'base-embedding', 3),
    }
    for section, (model, api_key, embedding, retries) in expected.items():
        llm_cfg = default_config.get_llm_config(section)
        assert llm_cfg.model == model
        assert llm_cfg.api_key.get_secret_value() == api_key
        assert llm_cfg.embedding_model == embedding
        assert llm_cfg.num_retries == retries
def test_load_from_toml_llm_custom_overrides_all(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """A custom LLM table that sets every attribute fully replaces [llm] values."""
    config_text = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
embedding_model = "base-embedding"
num_retries = 3
[llm.custom_full]
model = "full-custom-model"
api_key = "full-custom-api-key"
embedding_model = "full-custom-embedding"
num_retries = 10
"""
    config_path = tmp_path / 'full_override_llm.toml'
    config_path.write_text(config_text)
    load_from_toml(default_config, str(config_path))

    def check(section, model, api_key, embedding_model, num_retries):
        # Compare one loaded LLM section field by field.
        cfg = default_config.get_llm_config(section)
        assert cfg.model == model
        assert cfg.api_key.get_secret_value() == api_key
        assert cfg.embedding_model == embedding_model
        assert cfg.num_retries == num_retries

    # The generic table is left untouched by the custom one.
    check('llm', 'base-model', 'base-api-key', 'base-embedding', 3)
    # Every attribute comes from [llm.custom_full].
    check(
        'custom_full',
        'full-custom-model',
        'full-custom-api-key',
        'full-custom-embedding',
        10,
    )
def test_load_from_toml_llm_custom_partial_override(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """Custom LLM tables may override some [llm] attributes and inherit the rest.

    ``custom1`` overrides only 'model' and 'api_key'; ``custom2`` additionally
    overrides 'num_retries'. Both inherit 'embedding_model' from [llm].
    """
    load_from_toml(default_config, generic_llm_toml)

    cases = [
        # (section, model, api_key, num_retries)
        ('custom1', 'custom-model-1', 'custom-api-key-1', 3),  # retries from [llm]
        ('custom2', 'custom-model-2', 'custom-api-key-2', 5),  # retries overridden
    ]
    for section, want_model, want_key, want_retries in cases:
        cfg = default_config.get_llm_config(section)
        assert cfg.model == want_model
        assert cfg.api_key.get_secret_value() == want_key
        assert cfg.embedding_model == 'base-embedding'
        assert cfg.num_retries == want_retries
def test_load_from_toml_llm_custom_no_override(
    default_config: AppConfig, generic_llm_toml: str
) -> None:
    """A custom LLM table with only 'model' and 'api_key' inherits everything else.

    ``custom3`` sets nothing beyond its mandatory keys, so 'embedding_model'
    and 'num_retries' must come from the generic [llm] table.
    """
    load_from_toml(default_config, generic_llm_toml)

    cfg = default_config.get_llm_config('custom3')
    actual = (
        cfg.model,
        cfg.api_key.get_secret_value(),
        cfg.embedding_model,  # inherited from [llm]
        cfg.num_retries,  # inherited from [llm]
    )
    assert actual == ('custom-model-3', 'custom-api-key-3', 'base-embedding', 3)
def test_load_from_toml_llm_missing_generic(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """Without a generic [llm] table, a custom LLM table uses its own values and
    falls back to the built-in defaults for every other attribute.
    """
    path = tmp_path / 'custom_only_llm.toml'
    path.write_text(
        """
[core]
workspace_base = "./workspace"
[llm.custom_only]
model = "custom-only-model"
api_key = "custom-only-api-key"
"""
    )
    load_from_toml(default_config, str(path))

    only = default_config.get_llm_config('custom_only')
    # Set explicitly in [llm.custom_only].
    assert only.model == 'custom-only-model'
    assert only.api_key.get_secret_value() == 'custom-only-api-key'
    # Set nowhere in the file, so the defaults apply.
    assert only.embedding_model == 'local'
    assert only.num_retries == 8
def test_load_from_toml_llm_invalid_config(
    default_config: AppConfig, tmp_path: pathlib.Path
) -> None:
    """Test that a custom LLM table containing only an unknown attribute does
    not override the generic [llm] configuration.

    Looking up the invalid table is expected to return the generic [llm]
    values ('model', 'api_key', 'num_retries'). [llm] does not set
    'embedding_model' here, so it stays at its default. Any warnings emitted
    while loading are not asserted by this test.
    """
    toml_content = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
num_retries = 3
[llm.invalid_custom]
unknown_attr = "should_not_exist"
"""
    toml_file = tmp_path / 'invalid_custom_llm.toml'
    toml_file.write_text(toml_content)
    load_from_toml(default_config, str(toml_file))

    # Verify generic LLM is loaded correctly
    generic_llm = default_config.get_llm_config('llm')
    assert generic_llm.model == 'base-model'
    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
    assert generic_llm.num_retries == 3

    # Verify invalid_custom LLM does not override generic attributes
    custom_invalid = default_config.get_llm_config('invalid_custom')
    assert custom_invalid.model == 'base-model'
    assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
    assert custom_invalid.num_retries == 3  # from [llm], not the default
    assert custom_invalid.embedding_model == 'local'  # default value