entropice/tests/test_training.py

223 lines
7.5 KiB
Python
Raw Normal View History

2026-01-16 20:33:10 +01:00
"""Tests for training.py module, specifically random_cv function.

This test suite validates the random_cv training function across all model-task
combinations using a minimal hex level 3 grid with synopsis temporal mode.

Test Coverage:
    - All 12 model-task combinations: 4 models (espa, xgboost, rf, knn)
      x 3 tasks (binary, count, density)
    - Device handling for each model type (torch/CUDA/cuML compatibility)
    - Multi-label target dataset support
    - Temporal mode configuration (synopsis)
    - Output file creation and validation

Running Tests:
    # Run all training tests (18 tests total, 3 iterations each)
    pixi run pytest tests/test_training.py -v

    # Run only device handling tests
    pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v

    # Run a specific model-task combination (quote the node ID so the shell
    # does not interpret the square brackets as a glob pattern)
    pixi run pytest "tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa]" -v

Note: Tests use minimal iterations (3) and level 3 grid for speed.
Full production runs use higher iteration counts (100-2000).
"""
import shutil
import pytest
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.training import CVSettings, random_cv
from entropice.utils.types import Model, Task
@pytest.fixture(scope="module")
def test_ensemble():
    """Provide the small DatasetEnsemble shared by the tests in this module.

    The ensemble is configured with a level-3 hex grid, the "synopsis"
    temporal mode and a single member so that the tests stay fast.
    """
    ensemble_kwargs = {
        "grid": "hex",
        "level": 3,
        "temporal_mode": "synopsis",
        "members": ["AlphaEarth"],  # a single member keeps the tests fast
        "add_lonlat": True,
    }
    return DatasetEnsemble(**ensemble_kwargs)
@pytest.fixture
def cleanup_results():
    """Yield a registrar for result directories to delete on teardown.

    A test calls the yielded function with each results directory it creates.
    Once the test has finished, every registered directory that still exists
    is removed; nothing else is touched.
    """
    pending = []

    def register_dir(results_dir):
        """Remember ``results_dir`` for removal and hand it back unchanged."""
        pending.append(results_dir)
        return results_dir

    yield register_dir

    # Teardown: delete only what this particular test registered.
    for path in pending:
        if not path.exists():
            continue
        shutil.rmtree(path)
# Every model listed here is paired with every task by the parametrized tests.
# Some pairings may not be meaningful, but all of them are exercised to check
# robustness.
MODELS: list[Model] = [
    "espa",
    "xgboost",
    "rf",
    "knn",
]
TASKS: list[Task] = [
    "binary",
    "count",
    "density",
]
class TestRandomCV:
    """Test suite for random_cv function."""

    @staticmethod
    def _expected_output_files(model: Model, task: Task) -> list[str]:
        """Return the file names a random_cv run is expected to produce.

        Args:
            model: Model name used for the run.
            task: Task name used for the run.

        Returns:
            File names (relative to the results directory) that must exist.
        """
        expected_files = [
            "search_settings.toml",
            "best_estimator_model.pkl",
            "search_results.parquet",
            "metrics.toml",
            "predicted_probabilities.parquet",
        ]
        # A confusion matrix is only expected for "binary" and for tasks that
        # carry the "_regimes" suffix. Plain "count" and "density" might be
        # regression, so no confusion matrix is required for them.
        if task == "binary" or "_regimes" in task:
            expected_files.append("confusion_matrix.nc")
        # Of the tested models, only knn is not expected to write a state file.
        if model in ("espa", "xgboost", "rf"):
            expected_files.append("best_estimator_state.nc")
        return expected_files

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("task", TASKS)
    def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results):
        """Test random_cv with all model-task combinations.

        This test runs 3 iterations for each combination to verify:
        - The function completes without errors
        - Device handling works correctly for each model type
        - All output files are created
        """
        # Use darts_v1 as the primary target for all tests
        settings = CVSettings(
            n_iter=3,
            task=task,
            target="darts_v1",
            model=model,
        )

        # Run the cross-validation and get the results directory
        results_dir = random_cv(
            dataset_ensemble=test_ensemble,
            settings=settings,
            experiment="test_training",
        )
        # Register before asserting so the directory is removed even if an
        # assertion below fails.
        cleanup_results(results_dir)

        # Verify results directory was created
        assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"

        # Verify all expected output files exist
        for filename in self._expected_output_files(model, task):
            filepath = results_dir / filename
            assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"

    @pytest.mark.parametrize("model", MODELS)
    def test_device_handling(self, test_ensemble, model: Model, cleanup_results):
        """Test that device handling works correctly for each model type.

        Different models require different device configurations:
        - espa: Uses torch with array API dispatch
        - xgboost: Uses CUDA without array API dispatch
        - rf/knn: GPU-accelerated via cuML

        A RuntimeError whose message mentions a device-related keyword is
        reported as a test failure; any other RuntimeError is re-raised.
        """
        settings = CVSettings(
            n_iter=3,
            task="binary",  # Simple binary task for device testing
            target="darts_v1",
            model=model,
        )

        # This should complete without device-related errors
        try:
            results_dir = random_cv(
                dataset_ensemble=test_ensemble,
                settings=settings,
                experiment="test_training",
            )
        except RuntimeError as e:
            # Check if error is device-related
            error_msg = str(e).lower()
            device_keywords = ("cuda", "gpu", "device", "cpu", "torch", "cupy")
            if any(keyword in error_msg for keyword in device_keywords):
                pytest.fail(f"Device handling error for {model=}: {e}")
            else:
                # Re-raise non-device errors
                raise
        else:
            cleanup_results(results_dir)

    def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
        """Test random_cv with multi-label target dataset."""
        settings = CVSettings(
            n_iter=3,
            task="binary",
            target="darts_mllabels",
            model="espa",
        )

        # Run the cross-validation and get the results directory
        results_dir = random_cv(
            dataset_ensemble=test_ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify results were created
        assert results_dir.exists(), "Results directory not created"
        assert (results_dir / "search_settings.toml").exists()

    def test_temporal_mode_synopsis(self, cleanup_results):
        """Test that temporal_mode='synopsis' is correctly used.

        Builds its own ensemble with temporal_mode="synopsis", runs random_cv
        and checks that the stored search settings record that mode.
        """
        import toml

        ensemble = DatasetEnsemble(
            grid="hex",
            level=3,
            temporal_mode="synopsis",
            members=["AlphaEarth"],
            add_lonlat=True,
        )
        settings = CVSettings(
            n_iter=3,
            task="binary",
            target="darts_v1",
            model="espa",
        )

        # This should use synopsis mode (all years aggregated)
        results_dir = random_cv(
            dataset_ensemble=ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify the settings were stored correctly
        assert results_dir.exists(), "Results directory not created"
        # TOML documents are UTF-8; do not depend on the platform default encoding.
        with open(results_dir / "search_settings.toml", encoding="utf-8") as f:
            stored_settings = toml.load(f)
        assert stored_settings["settings"]["temporal_mode"] == "synopsis"