"""Tests for dataset.py module, specifically DatasetEnsemble class."""

from collections.abc import Generator

import geopandas as gpd
import numpy as np
import pandas as pd
import pytest

from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.utils.types import all_target_datasets, all_tasks


@pytest.fixture
def sample_ensemble() -> Generator[DatasetEnsemble]:
    """Create a sample DatasetEnsemble for testing with minimal data."""
    yield DatasetEnsemble(
        grid="hex",
        level=3,  # Use level 3 for much faster tests
        members=["AlphaEarth"],  # Use only one member for faster tests
        add_lonlat=True,
    )


@pytest.fixture
def sample_ensemble_v2() -> Generator[DatasetEnsemble]:
    """Create a second sample DatasetEnsemble, originally intended for v2-target tests.

    The target dataset is selected per call (``get_targets(target=...)`` /
    ``create_training_set(target=...)``), not on the ensemble itself, so this
    fixture is configured identically to ``sample_ensemble``.

    NOTE(review): no test in this module appears to request this fixture —
    consider removing it or giving it a distinct configuration.
    """
    yield DatasetEnsemble(
        grid="hex",
        level=3,  # Use level 3 for much faster tests
        members=["AlphaEarth"],  # Use only one member for faster tests
        add_lonlat=True,
    )


class TestDatasetEnsemble:
    """Test suite for DatasetEnsemble class."""

    def test_initialization(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that DatasetEnsemble initializes correctly."""
        assert sample_ensemble.grid == "hex"
        assert sample_ensemble.level == 3
        assert "AlphaEarth" in sample_ensemble.members
        assert sample_ensemble.add_lonlat is True

    def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that get_targets() returns a GeoDataFrame."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
        assert isinstance(targets, gpd.GeoDataFrame)

    def test_get_targets_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that get_targets() returns expected columns."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")

        # Should have required columns
        assert "geometry" in targets.columns
        assert "y" in targets.columns  # target label
        assert "z" in targets.columns  # raw value
        assert targets.index.name == "cell_id"

    def test_get_targets_binary_task(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that binary task creates proper labels."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")

        # Binary task should have categorical y with "No RTS" and "RTS"
        assert targets["y"].dtype.name == "category"
        assert set(targets["y"].cat.categories) == {"No RTS", "RTS"}

        # z should be numeric (count)
        assert pd.api.types.is_numeric_dtype(targets["z"])

    def test_get_targets_count_task(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that count task creates proper labels."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="count")

        # Count task should have numeric y and z that are identical
        assert pd.api.types.is_numeric_dtype(targets["y"])
        assert pd.api.types.is_numeric_dtype(targets["z"])
        assert (targets["y"] == targets["z"]).all()

    def test_get_targets_density_task(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that density task creates proper labels."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="density")

        # Density task should have numeric y and z that are identical
        assert pd.api.types.is_numeric_dtype(targets["y"])
        assert pd.api.types.is_numeric_dtype(targets["z"])
        assert (targets["y"] == targets["z"]).all()

    def test_make_features_returns_dataframe(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that make_features() returns a DataFrame."""
        # Get a small subset of cell_ids for faster testing
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
        cell_ids: pd.Series = targets.index.to_series().head(50)

        features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
        assert isinstance(features, pd.DataFrame)

    def test_make_features_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that make_features() returns expected feature columns."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
        cell_ids: pd.Series = targets.index.to_series().head(50)

        features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")

        # Should NOT have geometry column
        assert "geometry" not in features.columns

        # Should have location columns if add_lonlat is True
        if sample_ensemble.add_lonlat:
            assert "x" in features.columns
            assert "y" in features.columns

        # Should have grid property columns
        assert "cell_area" in features.columns
        assert "land_area" in features.columns

        # Should have member feature columns
        assert any(col.startswith("embeddings_") for col in features.columns)

    def test_inference_df_batches_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that create_inference_df produces batches with consistent columns."""
        batch_size: int = 100
        batches: list[pd.DataFrame] = list(
            sample_ensemble.create_inference_df(batch_size=batch_size, cache_mode="none")
        )

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # All batches should have the same columns
        first_batch_cols = set(batches[0].columns)
        for i, batch in enumerate(batches[1:], start=1):
            assert set(batch.columns) == first_batch_cols, (
                f"Batch {i} has different columns than batch 0. "
                f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
            )

    def test_training_set_structure(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that create_training_set creates proper structure."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Check that TrainingSet has all expected attributes
        assert hasattr(training_set, "targets")
        assert hasattr(training_set, "features")
        assert hasattr(training_set, "X")
        assert hasattr(training_set, "y")
        assert hasattr(training_set, "z")
        assert hasattr(training_set, "split")

    def test_training_set_feature_columns(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that create_training_set creates proper feature columns."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Get the feature column names
        feature_cols: set[str] = set(training_set.features.columns)

        # These columns should NOT be in features
        assert "geometry" not in feature_cols
        # Note: "y" and "x" are spatial coordinates from the grid and are valid features
        # The target labels are stored in training_set.targets["y"], not in features

    def test_training_set_no_target_leakage(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that target information doesn't leak into features."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Features should not contain target-related columns
        feature_cols: pd.Index = training_set.features.columns
        assert all(not col.startswith("darts_") for col in feature_cols)

    def test_inference_vs_training_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that inference batches have the same features as training data.

        This test simulates the workflow in training.py and inference.py to ensure
        feature consistency between training and inference.
        """
        # Step 1: Create training set (as in training.py)
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )
        training_feature_cols: set[str] = set(training_set.features.columns)

        # Step 2: Create inference batch (as in inference.py)
        batch_generator: Generator[pd.DataFrame] = sample_ensemble.create_inference_df(
            batch_size=100, cache_mode="none"
        )
        batch: pd.DataFrame | None = next(batch_generator, None)

        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")

        assert batch is not None  # For type checker
        inference_feature_cols: set[str] = set(batch.columns)

        # The features should match!
        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
            f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
            f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
        )

    def test_all_tasks_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that all task types produce consistent features."""
        feature_sets: dict[str, set[str]] = {}

        for task in all_tasks:
            training_set: TrainingSet = sample_ensemble.create_training_set(
                task=task,
                target="darts_v1",
                device="cpu",
                cache_mode="none",
            )
            feature_sets[task] = set(training_set.features.columns)

        # All tasks should have the same features
        binary_features: set[str] = feature_sets["binary"]
        for task, features in feature_sets.items():
            assert features == binary_features, (
                f"Task '{task}' has different features than 'binary'.\n"
                f"Difference: {features.symmetric_difference(binary_features)}"
            )

    def test_training_set_shapes(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that training set has correct shapes."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        n_features: int = len(training_set.features.columns)
        n_samples_train: int = training_set.X.train.shape[0]
        n_samples_test: int = training_set.X.test.shape[0]

        # Check X shapes
        assert training_set.X.train.shape == (n_samples_train, n_features)
        assert training_set.X.test.shape == (n_samples_test, n_features)

        # Check y shapes
        assert training_set.y.train.shape == (n_samples_train,)
        assert training_set.y.test.shape == (n_samples_test,)

        # Check that train + test = total samples
        assert len(training_set) == n_samples_train + n_samples_test

    def test_no_nan_in_training_features(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that training features don't contain NaN values."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # device="cpu" is expected to keep the arrays on the host, so np.asarray can
        # convert them. NOTE: a CuPy (GPU) array would NOT convert implicitly here;
        # it would need an explicit `.get()` first.
        X_train: np.ndarray = np.asarray(training_set.X.train)  # noqa: N806
        X_test: np.ndarray = np.asarray(training_set.X.test)  # noqa: N806

        assert not np.isnan(X_train).any(), "Training features contain NaN values"
        assert not np.isnan(X_test).any(), "Test features contain NaN values"

    def test_batch_coverage(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that batches cover all data without duplication."""
        batches: list[pd.DataFrame] = list(sample_ensemble.create_inference_df(batch_size=100, cache_mode="none"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Collect all cell_ids from batches
        batch_cell_ids: set = set()
        for batch in batches:
            batch_ids = set(batch.index)
            # Check for duplicates across batches
            overlap = batch_cell_ids.intersection(batch_ids)
            assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
            batch_cell_ids.update(batch_ids)

        # Check that all cell_ids from grid are in batches
        all_cell_ids = set(sample_ensemble.cell_ids.to_numpy())
        assert batch_cell_ids == all_cell_ids, (
            f"Batch coverage mismatch.\n"
            f"Missing from batches: {all_cell_ids - batch_cell_ids}\n"
            f"Extra in batches: {batch_cell_ids - all_cell_ids}"
        )


class TestDatasetEnsembleEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_task_raises_error(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that an invalid task raises NotImplementedError."""
        with pytest.raises(NotImplementedError, match=r"Task .* not supported"):
            sample_ensemble.create_training_set(
                task="invalid",  # type: ignore[arg-type]
                target="darts_v1",
                device="cpu",
                cache_mode="none",  # consistent with every other call: avoid touching the cache
            )

    def test_target_labels_property(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that TrainingSet provides target labels for categorical tasks."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Binary task should have labels
        labels: list[str] | None = training_set.target_labels
        assert labels is not None
        assert isinstance(labels, list)
        assert "No RTS" in labels
        assert "RTS" in labels

    def test_target_labels_none_for_regression(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that regression tasks return None for target_labels."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="count", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Regression task should return None
        assert training_set.target_labels is None

    def test_to_dataframe_method(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that to_dataframe returns expected structure."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Get full dataset
        df_full: pd.DataFrame = training_set.to_dataframe(split=None)
        assert isinstance(df_full, pd.DataFrame)
        assert "label" in df_full.columns
        assert len(df_full) == len(training_set)

        # Get train split
        df_train = training_set.to_dataframe(split="train")
        assert len(df_train) < len(df_full)
        assert (training_set.split[df_train.index] == "train").all()

        # Get test split
        df_test = training_set.to_dataframe(split="test")
        assert len(df_test) < len(df_full)
        assert (training_set.split[df_test.index] == "test").all()

        # Train + test should equal full
        assert len(df_train) + len(df_test) == len(df_full)

    @pytest.mark.parametrize("target", all_target_datasets)
    def test_different_targets(self, sample_ensemble: DatasetEnsemble, target: str) -> None:
        """Test that each target dataset can be used to build a training set.

        Every target is its own parametrized case, so an unavailable target only
        skips that one case instead of aborting the remaining targets. Only the
        construction step is guarded: a failing assertion is reported as a failure
        rather than being swallowed into a skip.
        """
        try:
            training_set: TrainingSet = sample_ensemble.create_training_set(
                task="binary",
                target=target,  # type: ignore[arg-type]  # values come from all_target_datasets
                device="cpu",
                cache_mode="none",
            )
        except Exception as e:  # noqa: BLE001 - any load failure means the target is unavailable here
            pytest.skip(f"Target {target} not available: {e}")

        assert len(training_set) > 0

    def test_feature_names_property(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that feature_names property returns list of feature names."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        feature_names: list[str] = training_set.feature_names
        assert isinstance(feature_names, list)
        assert len(feature_names) == training_set.features.shape[1]
        assert all(isinstance(name, str) for name in feature_names)