diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 8420ab6..c1959f1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,91 +1,129 @@ """Tests for dataset.py module, specifically DatasetEnsemble class.""" +from collections.abc import Generator + import geopandas as gpd import numpy as np +import pandas as pd import pytest -from entropice.ml.dataset import DatasetEnsemble +from entropice.ml.dataset import DatasetEnsemble, TrainingSet +from entropice.utils.types import all_target_datasets, all_tasks @pytest.fixture -def sample_ensemble(): +def sample_ensemble() -> Generator[DatasetEnsemble]: """Create a sample DatasetEnsemble for testing with minimal data.""" - return DatasetEnsemble( + yield DatasetEnsemble( grid="hex", level=3, # Use level 3 for much faster tests - target="darts_rts", members=["AlphaEarth"], # Use only one member for faster tests add_lonlat=True, - filter_target="darts_has_coverage", # Filter to reduce dataset size ) @pytest.fixture -def sample_ensemble_mllabels(): - """Create a sample DatasetEnsemble with mllabels target.""" - return DatasetEnsemble( +def sample_ensemble_v2() -> Generator[DatasetEnsemble]: + """Create a sample DatasetEnsemble for testing with v2 target.""" + yield DatasetEnsemble( grid="hex", level=3, # Use level 3 for much faster tests - target="darts_mllabels", members=["AlphaEarth"], # Use only one member for faster tests add_lonlat=True, - filter_target="dartsml_has_coverage", # Filter to reduce dataset size ) class TestDatasetEnsemble: """Test suite for DatasetEnsemble class.""" - def test_initialization(self, sample_ensemble): + def test_initialization(self, sample_ensemble: DatasetEnsemble) -> None: """Test that DatasetEnsemble initializes correctly.""" assert sample_ensemble.grid == "hex" assert sample_ensemble.level == 3 - assert sample_ensemble.target == "darts_rts" assert "AlphaEarth" in sample_ensemble.members assert sample_ensemble.add_lonlat is True - def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels): - """Test that covcol property returns correct column name.""" - assert sample_ensemble.covcol == "darts_has_coverage" - assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage" + def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that get_targets() returns a GeoDataFrame.""" + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary") + assert isinstance(targets, gpd.GeoDataFrame) - def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels): - """Test that taskcol returns correct column name for different tasks.""" - assert sample_ensemble.taskcol("binary") == "darts_has_rts" - assert sample_ensemble.taskcol("count") == "darts_rts_count" - assert sample_ensemble.taskcol("density") == "darts_rts_density" + def test_get_targets_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that get_targets() returns expected columns.""" + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary") - assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts" - assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count" - assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density" + # Should have required columns + assert "geometry" in targets.columns + assert "y" in targets.columns # target label + assert "z" in targets.columns # raw value + assert targets.index.name == "cell_id" - def test_create_returns_geodataframe(self, sample_ensemble): - """Test that create() returns a GeoDataFrame.""" - dataset = sample_ensemble.create(cache_mode="n") - assert isinstance(dataset, gpd.GeoDataFrame) + def test_get_targets_binary_task(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that binary task creates proper labels.""" + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary") - def test_create_has_expected_columns(self, sample_ensemble): - """Test that create() returns dataset with expected columns.""" - dataset = sample_ensemble.create(cache_mode="n") + # Binary task should have categorical y with "No RTS" and "RTS" + assert targets["y"].dtype.name == "category" + assert set(targets["y"].cat.categories) == {"No RTS", "RTS"} - # Should have geometry column - assert "geometry" in dataset.columns + # z should be numeric (count) + assert pd.api.types.is_numeric_dtype(targets["z"]) - # Should have lat/lon if add_lonlat is True + def test_get_targets_count_task(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that count task creates proper labels.""" + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="count") + + # Count task should have numeric y and z that are identical + assert pd.api.types.is_numeric_dtype(targets["y"]) + assert pd.api.types.is_numeric_dtype(targets["z"]) + assert (targets["y"] == targets["z"]).all() + + def test_get_targets_density_task(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that density task creates proper labels.""" + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="density") + + # Density task should have numeric y and z that are identical + assert pd.api.types.is_numeric_dtype(targets["y"]) + assert pd.api.types.is_numeric_dtype(targets["z"]) + assert (targets["y"] == targets["z"]).all() + + def test_make_features_returns_dataframe(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that make_features() returns a DataFrame.""" + # Get a small subset of cell_ids for faster testing + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary") + cell_ids: pd.Series = targets.index.to_series().head(50) + + features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none") + assert isinstance(features, pd.DataFrame) + + def test_make_features_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that make_features() returns expected feature columns.""" + targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary") + cell_ids: pd.Series = targets.index.to_series().head(50) + + features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none") + + # Should NOT have geometry column + assert "geometry" not in features.columns + + # Should have location columns if add_lonlat is True if sample_ensemble.add_lonlat: - assert "lon" in dataset.columns - assert "lat" in dataset.columns + assert "x" in features.columns + assert "y" in features.columns - # Should have target columns (darts_*) - assert any(col.startswith("darts_") for col in dataset.columns) + # Should have grid property columns + assert "cell_area" in features.columns + assert "land_area" in features.columns - # Should have member data columns - assert len(dataset.columns) > 3 # More than just geometry, lat, lon + # Should have member feature columns + assert any(col.startswith("embeddings_") for col in features.columns) - def test_create_batches_consistency(self, sample_ensemble): - """Test that create_batches produces batches with consistent columns.""" - batch_size = 100 - batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n")) + def test_inference_df_batches_consistency(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that create_inference_df produces batches with consistent columns.""" + batch_size: int = 100 + batches: list[pd.DataFrame] = list( + sample_ensemble.create_inference_df(batch_size=batch_size, cache_mode="none") + ) if len(batches) == 0: pytest.skip("No batches created (dataset might be empty)") @@ -98,83 +136,67 @@ class TestDatasetEnsemble: f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}" ) - def test_create_vs_create_batches_columns(self, sample_ensemble): - """Test that create() and create_batches() return datasets with the same columns.""" - full_dataset = sample_ensemble.create(cache_mode="n") - batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n")) - - if len(batches) == 0: - pytest.skip("No batches created (dataset might be empty)") - - # Columns should be identical - full_cols = set(full_dataset.columns) - batch_cols = set(batches[0].columns) - - assert full_cols == batch_cols, ( - f"Column mismatch between create() and create_batches().\n" - f"Only in create(): {full_cols - batch_cols}\n" - f"Only in create_batches(): {batch_cols - full_cols}" + def test_training_set_structure(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that create_training_set creates proper structure.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" ) - def test_training_dataset_feature_columns(self, sample_ensemble): - """Test that create_cat_training_dataset creates proper feature columns.""" - training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") + # Check that TrainingSet has all expected attributes + assert hasattr(training_set, "targets") + assert hasattr(training_set, "features") + assert hasattr(training_set, "X") + assert hasattr(training_set, "y") + assert hasattr(training_set, "z") + assert hasattr(training_set, "split") - # Get the columns used for model inputs - model_input_cols = set(training_data.X.data.columns) + def test_training_set_feature_columns(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that create_training_set creates proper feature columns.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) - # These columns should NOT be in model inputs - assert "geometry" not in model_input_cols - assert sample_ensemble.covcol not in model_input_cols - assert sample_ensemble.taskcol("binary") not in model_input_cols + # Get the feature column names + feature_cols: set[str] = set(training_set.features.columns) - # No darts_* columns should be in model inputs - for col in model_input_cols: - assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs" + # These columns should NOT be in features + assert "geometry" not in feature_cols + # Note: "y" and "x" are spatial coordinates from the grid and are valid features + # The target labels are stored in training_set.targets["y"], not in features - def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels): - """Test feature columns for mllabels target.""" - training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu") + def test_training_set_no_target_leakage(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that target information doesn't leak into features.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) - model_input_cols = set(training_data.X.data.columns) + # Features should not contain target-related columns + feature_cols: pd.Index = training_set.features.columns + assert all(not col.startswith("darts_") for col in feature_cols) - # These columns should NOT be in model inputs - assert "geometry" not in model_input_cols - assert sample_ensemble_mllabels.covcol not in model_input_cols - assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols - - # No dartsml_* columns should be in model inputs - for col in model_input_cols: - assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs" - - def test_inference_vs_training_feature_consistency(self, sample_ensemble): - """Test that inference batches have the same features as training data after column dropping. + def test_inference_vs_training_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that inference batches have the same features as training data. This test simulates the workflow in training.py and inference.py to ensure feature consistency between training and inference. """ - # Step 1: Create training dataset (as in training.py) - # Use only a small subset by creating just one batch - training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") - training_feature_cols = set(training_data.X.data.columns) + # Step 1: Create training set (as in training.py) + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) + training_feature_cols: set[str] = set(training_set.features.columns) # Step 2: Create inference batch (as in inference.py) - # Get just the first batch to speed up test - batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n") - batch = next(batch_generator, None) + batch_generator: Generator[pd.DataFrame] = sample_ensemble.create_inference_df( + batch_size=100, cache_mode="none" + ) + batch: pd.DataFrame | None = next(batch_generator, None) if batch is None: pytest.skip("No batches created (dataset might be empty)") - # Simulate the column dropping in predict_proba - cols_to_drop = ["geometry"] - if sample_ensemble.target == "darts_mllabels": - cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")] - else: - cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] - - inference_batch = batch.drop(columns=cols_to_drop) - inference_feature_cols = set(inference_batch.columns) + assert batch is not None # For type checker + inference_feature_cols: set[str] = set(batch.columns) # The features should match! assert training_feature_cols == inference_feature_cols, ( @@ -185,85 +207,64 @@ class TestDatasetEnsemble: f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}" ) - def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels): - """Test feature consistency for mllabels target.""" - training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu") - training_feature_cols = set(training_data.X.data.columns) - - # Get just the first batch to speed up test - batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n") - batch = next(batch_generator, None) - - if batch is None: - pytest.skip("No batches created (dataset might be empty)") - - # Simulate the column dropping in predict_proba - cols_to_drop = ["geometry"] - if sample_ensemble_mllabels.target == "darts_mllabels": - cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")] - else: - cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] - - inference_batch = batch.drop(columns=cols_to_drop) - inference_feature_cols = set(inference_batch.columns) - - assert training_feature_cols == inference_feature_cols, ( - f"Feature mismatch between training and inference!\n" - f"Only in training: {training_feature_cols - inference_feature_cols}\n" - f"Only in inference: {inference_feature_cols - training_feature_cols}" - ) - - def test_all_tasks_feature_consistency(self, sample_ensemble): + def test_all_tasks_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None: """Test that all task types produce consistent features.""" - tasks = ["binary", "count", "density"] - feature_sets = {} + feature_sets: dict[str, set[str]] = {} - for task in tasks: - training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu") - feature_sets[task] = set(training_data.X.data.columns) + for task in all_tasks: + training_set: TrainingSet = sample_ensemble.create_training_set( + task=task, + target="darts_v1", + device="cpu", + cache_mode="none", + ) + feature_sets[task] = set(training_set.features.columns) # All tasks should have the same features - binary_features = feature_sets["binary"] + binary_features: set[str] = feature_sets["binary"] for task, features in feature_sets.items(): assert features == binary_features, ( f"Task '{task}' has different features than 'binary'.\n" f"Difference: {features.symmetric_difference(binary_features)}" ) - def test_training_dataset_shapes(self, sample_ensemble): - """Test that training dataset has correct shapes.""" - training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") + def test_training_set_shapes(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that training set has correct shapes.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) - n_features = len(training_data.X.data.columns) - n_samples_train = training_data.X.train.shape[0] - n_samples_test = training_data.X.test.shape[0] + n_features: int = len(training_set.features.columns) + n_samples_train: int = training_set.X.train.shape[0] + n_samples_test: int = training_set.X.test.shape[0] # Check X shapes - assert training_data.X.train.shape == (n_samples_train, n_features) - assert training_data.X.test.shape == (n_samples_test, n_features) + assert training_set.X.train.shape == (n_samples_train, n_features) + assert training_set.X.test.shape == (n_samples_test, n_features) # Check y shapes - assert training_data.y.train.shape == (n_samples_train,) - assert training_data.y.test.shape == (n_samples_test,) + assert training_set.y.train.shape == (n_samples_train,) + assert training_set.y.test.shape == (n_samples_test,) # Check that train + test = total samples - assert len(training_data.dataset) == n_samples_train + n_samples_test + assert len(training_set) == n_samples_train + n_samples_test - def test_no_nan_in_training_features(self, sample_ensemble): + def test_no_nan_in_training_features(self, sample_ensemble: DatasetEnsemble) -> None: """Test that training features don't contain NaN values.""" - training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) # Convert to numpy for checking (handles both numpy and cupy arrays) - X_train = np.asarray(training_data.X.train) - X_test = np.asarray(training_data.X.test) + X_train: np.ndarray = np.asarray(training_set.X.train) # noqa: N806 + X_test: np.ndarray = np.asarray(training_set.X.test) # noqa: N806 assert not np.isnan(X_train).any(), "Training features contain NaN values" assert not np.isnan(X_test).any(), "Test features contain NaN values" - def test_batch_coverage(self, sample_ensemble): + def test_batch_coverage(self, sample_ensemble: DatasetEnsemble) -> None: """Test that batches cover all data without duplication.""" - full_dataset = sample_ensemble.create(cache_mode="n") - batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n")) + batches: list[pd.DataFrame] = list(sample_ensemble.create_inference_df(batch_size=100, cache_mode="none")) if len(batches) == 0: pytest.skip("No batches created (dataset might be empty)") @@ -277,34 +278,91 @@ class TestDatasetEnsemble: assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches" batch_cell_ids.update(batch_ids) - # Check that all cell_ids from full dataset are in batches - full_cell_ids = set(full_dataset.index) - assert batch_cell_ids == full_cell_ids, ( + # Check that all cell_ids from grid are in batches + all_cell_ids = set(sample_ensemble.cell_ids.to_numpy()) + assert batch_cell_ids == all_cell_ids, ( f"Batch coverage mismatch.\n" - f"Missing from batches: {full_cell_ids - batch_cell_ids}\n" - f"Extra in batches: {batch_cell_ids - full_cell_ids}" + f"Missing from batches: {all_cell_ids - batch_cell_ids}\n" + f"Extra in batches: {batch_cell_ids - all_cell_ids}" ) class TestDatasetEnsembleEdgeCases: """Test edge cases and error handling.""" - def test_invalid_task_raises_error(self, sample_ensemble): + def test_invalid_task_raises_error(self, sample_ensemble: DatasetEnsemble) -> None: """Test that invalid task raises ValueError.""" - with pytest.raises(ValueError, match="Invalid task"): - sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu") # type: ignore + with pytest.raises(NotImplementedError, match=r"Task .* not supported"): + sample_ensemble.create_training_set(task="invalid", target="darts_v1", device="cpu") # type: ignore[arg-type] - def test_stats_method(self, sample_ensemble): - """Test that get_stats returns expected structure.""" - stats = sample_ensemble.get_stats() + def test_target_labels_property(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that TrainingSet provides target labels for categorical tasks.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) - assert "target" in stats - assert "num_target_samples" in stats - assert "members" in stats - assert "total_features" in stats + # Binary task should have labels + labels: list[str] | None = training_set.target_labels + assert labels is not None + assert isinstance(labels, list) + assert "No RTS" in labels + assert "RTS" in labels - # Check that members dict contains info for each member - for member in sample_ensemble.members: - assert member in stats["members"] - assert "variables" in stats["members"][member] - assert "num_features" in stats["members"][member] + def test_target_labels_none_for_regression(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that regression tasks return None for target_labels.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="count", target="darts_v1", device="cpu", cache_mode="none" + ) + + # Regression task should return None + assert training_set.target_labels is None + + def test_to_dataframe_method(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that to_dataframe returns expected structure.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) + + # Get full dataset + df_full: pd.DataFrame = training_set.to_dataframe(split=None) + assert isinstance(df_full, pd.DataFrame) + assert "label" in df_full.columns + assert len(df_full) == len(training_set) + + # Get train split + df_train = training_set.to_dataframe(split="train") + assert len(df_train) < len(df_full) + assert all(training_set.split[df_train.index] == "train") + + # Get test split + df_test = training_set.to_dataframe(split="test") + assert len(df_test) < len(df_full) + assert all(training_set.split[df_test.index] == "test") + + # Train + test should equal full + assert len(df_train) + len(df_test) == len(df_full) + + def test_different_targets(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that different target datasets can be used.""" + for target in all_target_datasets: + try: + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", + target=target, + device="cpu", + cache_mode="none", + ) + assert len(training_set) > 0 + except Exception as e: + pytest.skip(f"Target {target} not available: {e}") + + def test_feature_names_property(self, sample_ensemble: DatasetEnsemble) -> None: + """Test that feature_names property returns list of feature names.""" + training_set: TrainingSet = sample_ensemble.create_training_set( + task="binary", target="darts_v1", device="cpu", cache_mode="none" + ) + + feature_names: list[str] = training_set.feature_names + assert isinstance(feature_names, list) + assert len(feature_names) == training_set.features.shape[1] + assert all(isinstance(name, str) for name in feature_names)