"""Tests for the random_cv training function (imported from entropice.ml.randomsearch).

This test suite validates the random_cv training function across all model-task
combinations using a minimal hex level 3 grid with synopsis temporal mode.

Test Coverage:
    - All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn
    - Device handling for each model type (torch/CUDA/cuML compatibility)
    - Multi-label target dataset support
    - Temporal mode configuration (synopsis)
    - Output file creation and validation

Running Tests:
    # Run all training tests (18 tests total, ~3 iterations each)
    pixi run pytest tests/test_training.py -v

    # Run only device handling tests
    pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v

    # Run a specific model-task combination
    pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v

Note:
    Tests use minimal iterations (3) and level 3 grid for speed.
    Full production runs use higher iteration counts (100-2000).
"""

import shutil

import pytest

from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.randomsearch import RunSettings, random_cv
from entropice.utils.types import Model, Task


@pytest.fixture(scope="module")
def test_ensemble():
    """Create a minimal DatasetEnsemble for testing.

    Uses hex level 3 grid with synopsis temporal mode for fast testing.
    The fixture is module-scoped, so one instance is shared by every test
    in this file that requests it.
    """
    return DatasetEnsemble(
        grid="hex",
        level=3,
        temporal_mode="synopsis",
        members=["AlphaEarth"],  # Use only one member for faster tests
        add_lonlat=True,
    )


@pytest.fixture
def cleanup_results():
    """Clean up results directories after each test.

    Yields a ``register_dir`` callable. A test passes it every results
    directory it created; once the test finishes, each registered directory
    that still exists is removed with ``shutil.rmtree``.
    """
    created_dirs = []

    def register_dir(results_dir):
        """Register a directory to be cleaned up and return it unchanged."""
        created_dirs.append(results_dir)
        return results_dir

    yield register_dir

    # Clean up only the directories created during this test
    for results_dir in created_dirs:
        if results_dir.exists():
            shutil.rmtree(results_dir)


# Model-task combinations to test
# Note: Not all combinations make sense, but we test all to ensure robustness
MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"]
TASKS: list[Task] = ["binary", "count", "density"]

# Substrings that mark a RuntimeError message as device-related.
_DEVICE_KEYWORDS = ("cuda", "gpu", "device", "cpu", "torch", "cupy")


def _expected_output_files(model: str, task: str) -> list[str]:
    """Return the file names a random_cv run is expected to write.

    Args:
        model: Model identifier (one of ``MODELS``).
        task: Task identifier (one of ``TASKS``, or a ``*_regimes`` variant).

    Returns:
        File names, relative to the results directory, that must exist.
    """
    expected_files = [
        "search_settings.toml",
        "best_estimator_model.pkl",
        "search_results.parquet",
        "metrics.toml",
        "predicted_probabilities.parquet",
    ]

    # A confusion matrix is expected only for classification-style tasks.
    # Note: count and density without the _regimes suffix might be regression.
    # The check is intentionally not restricted to the plain task names, so a
    # "*_regimes" task is actually covered if one is ever added to TASKS.
    if task == "binary" or "_regimes" in task:
        expected_files.append("confusion_matrix.nc")

    # Model-specific files: knn is the only model not expected to write a state file.
    if model in ("espa", "xgboost", "rf"):
        expected_files.append("best_estimator_state.nc")

    return expected_files


class TestRandomCV:
    """Test suite for random_cv function."""

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("task", TASKS)
    def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results) -> None:
        """Test random_cv with all model-task combinations.

        This test runs 3 iterations for each combination to verify:
        - The function completes without errors
        - Device handling works correctly for each model type
        - All output files are created
        """
        # Use darts_v1 as the primary target for all tests
        settings = RunSettings(
            n_iter=3,
            task=task,
            target="darts_v1",
            model=model,
        )

        # Run the cross-validation and get the results directory
        results_dir = random_cv(
            dataset_ensemble=test_ensemble,
            settings=settings,
            experiment="test_training",
        )
        # Register before asserting so the directory is removed even if a check fails
        cleanup_results(results_dir)

        # Verify results directory was created
        assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"

        # Verify all expected output files exist
        for filename in _expected_output_files(model, task):
            filepath = results_dir / filename
            assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"

    @pytest.mark.parametrize("model", MODELS)
    def test_device_handling(self, test_ensemble, model: Model, cleanup_results) -> None:
        """Test that device handling works correctly for each model type.

        Different models require different device configurations:
        - espa: Uses torch with array API dispatch
        - xgboost: Uses CUDA without array API dispatch
        - rf/knn: GPU-accelerated via cuML

        A ``RuntimeError`` whose message mentions a device keyword is reported
        as a test failure; any other ``RuntimeError`` is re-raised unchanged.
        """
        settings = RunSettings(
            n_iter=3,
            task="binary",  # Simple binary task for device testing
            target="darts_v1",
            model=model,
        )

        # This should complete without device-related errors
        try:
            results_dir = random_cv(
                dataset_ensemble=test_ensemble,
                settings=settings,
                experiment="test_training",
            )
        except RuntimeError as e:
            # Check if error is device-related
            error_msg = str(e).lower()
            if any(keyword in error_msg for keyword in _DEVICE_KEYWORDS):
                pytest.fail(f"Device handling error for {model=}: {e}")
            # Re-raise non-device errors
            raise
        else:
            cleanup_results(results_dir)

    def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results) -> None:
        """Test random_cv with multi-label target dataset."""
        settings = RunSettings(
            n_iter=3,
            task="binary",
            target="darts_mllabels",
            model="espa",
        )

        # Run the cross-validation and get the results directory
        results_dir = random_cv(
            dataset_ensemble=test_ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify results were created
        assert results_dir.exists(), "Results directory not created"
        assert (results_dir / "search_settings.toml").exists()

    def test_temporal_mode_synopsis(self, cleanup_results) -> None:
        """Test that temporal_mode='synopsis' is correctly used.

        Builds its own ensemble rather than using the shared fixture, then
        checks that the stored search settings record the temporal mode.
        """
        import toml

        ensemble = DatasetEnsemble(
            grid="hex",
            level=3,
            temporal_mode="synopsis",
            members=["AlphaEarth"],
            add_lonlat=True,
        )

        settings = RunSettings(
            n_iter=3,
            task="binary",
            target="darts_v1",
            model="espa",
        )

        # This should use synopsis mode (all years aggregated)
        results_dir = random_cv(
            dataset_ensemble=ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify the settings were stored correctly
        assert results_dir.exists(), "Results directory not created"
        # TOML documents are UTF-8; do not rely on the platform default encoding
        with open(results_dir / "search_settings.toml", encoding="utf-8") as f:
            stored_settings = toml.load(f)

        assert stored_settings["settings"]["temporal_mode"] == "synopsis"