Fix dataset test

This commit is contained in:
Tobias Hölzer 2026-01-19 17:05:44 +01:00
parent 87a0d03af1
commit 073502c51d

View file

@ -1,91 +1,129 @@
"""Tests for dataset.py module, specifically DatasetEnsemble class.""" """Tests for dataset.py module, specifically DatasetEnsemble class."""
from collections.abc import Generator
import geopandas as gpd import geopandas as gpd
import numpy as np import numpy as np
import pandas as pd
import pytest import pytest
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.utils.types import all_target_datasets, all_tasks
@pytest.fixture @pytest.fixture
def sample_ensemble(): def sample_ensemble() -> Generator[DatasetEnsemble]:
"""Create a sample DatasetEnsemble for testing with minimal data.""" """Create a sample DatasetEnsemble for testing with minimal data."""
return DatasetEnsemble( yield DatasetEnsemble(
grid="hex", grid="hex",
level=3, # Use level 3 for much faster tests level=3, # Use level 3 for much faster tests
target="darts_rts",
members=["AlphaEarth"], # Use only one member for faster tests members=["AlphaEarth"], # Use only one member for faster tests
add_lonlat=True, add_lonlat=True,
filter_target="darts_has_coverage", # Filter to reduce dataset size
) )
@pytest.fixture @pytest.fixture
def sample_ensemble_mllabels(): def sample_ensemble_v2() -> Generator[DatasetEnsemble]:
"""Create a sample DatasetEnsemble with mllabels target.""" """Create a sample DatasetEnsemble for testing with v2 target."""
return DatasetEnsemble( yield DatasetEnsemble(
grid="hex", grid="hex",
level=3, # Use level 3 for much faster tests level=3, # Use level 3 for much faster tests
target="darts_mllabels",
members=["AlphaEarth"], # Use only one member for faster tests members=["AlphaEarth"], # Use only one member for faster tests
add_lonlat=True, add_lonlat=True,
filter_target="dartsml_has_coverage", # Filter to reduce dataset size
) )
class TestDatasetEnsemble: class TestDatasetEnsemble:
"""Test suite for DatasetEnsemble class.""" """Test suite for DatasetEnsemble class."""
def test_initialization(self, sample_ensemble): def test_initialization(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that DatasetEnsemble initializes correctly.""" """Test that DatasetEnsemble initializes correctly."""
assert sample_ensemble.grid == "hex" assert sample_ensemble.grid == "hex"
assert sample_ensemble.level == 3 assert sample_ensemble.level == 3
assert sample_ensemble.target == "darts_rts"
assert "AlphaEarth" in sample_ensemble.members assert "AlphaEarth" in sample_ensemble.members
assert sample_ensemble.add_lonlat is True assert sample_ensemble.add_lonlat is True
def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels): def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that covcol property returns correct column name.""" """Test that get_targets() returns a GeoDataFrame."""
assert sample_ensemble.covcol == "darts_has_coverage" targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage" assert isinstance(targets, gpd.GeoDataFrame)
def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels): def test_get_targets_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that taskcol returns correct column name for different tasks.""" """Test that get_targets() returns expected columns."""
assert sample_ensemble.taskcol("binary") == "darts_has_rts" targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
assert sample_ensemble.taskcol("count") == "darts_rts_count"
assert sample_ensemble.taskcol("density") == "darts_rts_density"
assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts" # Should have required columns
assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count" assert "geometry" in targets.columns
assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density" assert "y" in targets.columns # target label
assert "z" in targets.columns # raw value
assert targets.index.name == "cell_id"
def test_create_returns_geodataframe(self, sample_ensemble): def test_get_targets_binary_task(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that create() returns a GeoDataFrame.""" """Test that binary task creates proper labels."""
dataset = sample_ensemble.create(cache_mode="n") targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
assert isinstance(dataset, gpd.GeoDataFrame)
def test_create_has_expected_columns(self, sample_ensemble): # Binary task should have categorical y with "No RTS" and "RTS"
"""Test that create() returns dataset with expected columns.""" assert targets["y"].dtype.name == "category"
dataset = sample_ensemble.create(cache_mode="n") assert set(targets["y"].cat.categories) == {"No RTS", "RTS"}
# Should have geometry column # z should be numeric (count)
assert "geometry" in dataset.columns assert pd.api.types.is_numeric_dtype(targets["z"])
# Should have lat/lon if add_lonlat is True def test_get_targets_count_task(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that count task creates proper labels."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="count")
# Count task should have numeric y and z that are identical
assert pd.api.types.is_numeric_dtype(targets["y"])
assert pd.api.types.is_numeric_dtype(targets["z"])
assert (targets["y"] == targets["z"]).all()
def test_get_targets_density_task(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that density task creates proper labels."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="density")
# Density task should have numeric y and z that are identical
assert pd.api.types.is_numeric_dtype(targets["y"])
assert pd.api.types.is_numeric_dtype(targets["z"])
assert (targets["y"] == targets["z"]).all()
def test_make_features_returns_dataframe(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that make_features() returns a DataFrame."""
# Get a small subset of cell_ids for faster testing
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
cell_ids: pd.Series = targets.index.to_series().head(50)
features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
assert isinstance(features, pd.DataFrame)
def test_make_features_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that make_features() returns expected feature columns."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
cell_ids: pd.Series = targets.index.to_series().head(50)
features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
# Should NOT have geometry column
assert "geometry" not in features.columns
# Should have location columns if add_lonlat is True
if sample_ensemble.add_lonlat: if sample_ensemble.add_lonlat:
assert "lon" in dataset.columns assert "x" in features.columns
assert "lat" in dataset.columns assert "y" in features.columns
# Should have target columns (darts_*) # Should have grid property columns
assert any(col.startswith("darts_") for col in dataset.columns) assert "cell_area" in features.columns
assert "land_area" in features.columns
# Should have member data columns # Should have member feature columns
assert len(dataset.columns) > 3 # More than just geometry, lat, lon assert any(col.startswith("embeddings_") for col in features.columns)
def test_create_batches_consistency(self, sample_ensemble): def test_inference_df_batches_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that create_batches produces batches with consistent columns.""" """Test that create_inference_df produces batches with consistent columns."""
batch_size = 100 batch_size: int = 100
batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n")) batches: list[pd.DataFrame] = list(
sample_ensemble.create_inference_df(batch_size=batch_size, cache_mode="none")
)
if len(batches) == 0: if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)") pytest.skip("No batches created (dataset might be empty)")
@ -98,83 +136,67 @@ class TestDatasetEnsemble:
f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}" f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
) )
def test_create_vs_create_batches_columns(self, sample_ensemble): def test_training_set_structure(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that create() and create_batches() return datasets with the same columns.""" """Test that create_training_set creates proper structure."""
full_dataset = sample_ensemble.create(cache_mode="n") training_set: TrainingSet = sample_ensemble.create_training_set(
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n")) task="binary", target="darts_v1", device="cpu", cache_mode="none"
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
# Columns should be identical
full_cols = set(full_dataset.columns)
batch_cols = set(batches[0].columns)
assert full_cols == batch_cols, (
f"Column mismatch between create() and create_batches().\n"
f"Only in create(): {full_cols - batch_cols}\n"
f"Only in create_batches(): {batch_cols - full_cols}"
) )
def test_training_dataset_feature_columns(self, sample_ensemble): # Check that TrainingSet has all expected attributes
"""Test that create_cat_training_dataset creates proper feature columns.""" assert hasattr(training_set, "targets")
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") assert hasattr(training_set, "features")
assert hasattr(training_set, "X")
assert hasattr(training_set, "y")
assert hasattr(training_set, "z")
assert hasattr(training_set, "split")
# Get the columns used for model inputs def test_training_set_feature_columns(self, sample_ensemble: DatasetEnsemble) -> None:
model_input_cols = set(training_data.X.data.columns) """Test that create_training_set creates proper feature columns."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
# These columns should NOT be in model inputs # Get the feature column names
assert "geometry" not in model_input_cols feature_cols: set[str] = set(training_set.features.columns)
assert sample_ensemble.covcol not in model_input_cols
assert sample_ensemble.taskcol("binary") not in model_input_cols
# No darts_* columns should be in model inputs # These columns should NOT be in features
for col in model_input_cols: assert "geometry" not in feature_cols
assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs" # Note: "y" and "x" are spatial coordinates from the grid and are valid features
# The target labels are stored in training_set.targets["y"], not in features
def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels): def test_training_set_no_target_leakage(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test feature columns for mllabels target.""" """Test that target information doesn't leak into features."""
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu") training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
model_input_cols = set(training_data.X.data.columns) # Features should not contain target-related columns
feature_cols: pd.Index = training_set.features.columns
assert all(not col.startswith("darts_") for col in feature_cols)
# These columns should NOT be in model inputs def test_inference_vs_training_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
assert "geometry" not in model_input_cols """Test that inference batches have the same features as training data.
assert sample_ensemble_mllabels.covcol not in model_input_cols
assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols
# No dartsml_* columns should be in model inputs
for col in model_input_cols:
assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"
def test_inference_vs_training_feature_consistency(self, sample_ensemble):
"""Test that inference batches have the same features as training data after column dropping.
This test simulates the workflow in training.py and inference.py to ensure This test simulates the workflow in training.py and inference.py to ensure
feature consistency between training and inference. feature consistency between training and inference.
""" """
# Step 1: Create training dataset (as in training.py) # Step 1: Create training set (as in training.py)
# Use only a small subset by creating just one batch training_set: TrainingSet = sample_ensemble.create_training_set(
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") task="binary", target="darts_v1", device="cpu", cache_mode="none"
training_feature_cols = set(training_data.X.data.columns) )
training_feature_cols: set[str] = set(training_set.features.columns)
# Step 2: Create inference batch (as in inference.py) # Step 2: Create inference batch (as in inference.py)
# Get just the first batch to speed up test batch_generator: Generator[pd.DataFrame] = sample_ensemble.create_inference_df(
batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n") batch_size=100, cache_mode="none"
batch = next(batch_generator, None) )
batch: pd.DataFrame | None = next(batch_generator, None)
if batch is None: if batch is None:
pytest.skip("No batches created (dataset might be empty)") pytest.skip("No batches created (dataset might be empty)")
# Simulate the column dropping in predict_proba assert batch is not None # For type checker
cols_to_drop = ["geometry"] inference_feature_cols: set[str] = set(batch.columns)
if sample_ensemble.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
inference_batch = batch.drop(columns=cols_to_drop)
inference_feature_cols = set(inference_batch.columns)
# The features should match! # The features should match!
assert training_feature_cols == inference_feature_cols, ( assert training_feature_cols == inference_feature_cols, (
@ -185,85 +207,64 @@ class TestDatasetEnsemble:
f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}" f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
) )
def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels): def test_all_tasks_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test feature consistency for mllabels target."""
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
training_feature_cols = set(training_data.X.data.columns)
# Get just the first batch to speed up test
batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
if batch is None:
pytest.skip("No batches created (dataset might be empty)")
# Simulate the column dropping in predict_proba
cols_to_drop = ["geometry"]
if sample_ensemble_mllabels.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
inference_batch = batch.drop(columns=cols_to_drop)
inference_feature_cols = set(inference_batch.columns)
assert training_feature_cols == inference_feature_cols, (
f"Feature mismatch between training and inference!\n"
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
f"Only in inference: {inference_feature_cols - training_feature_cols}"
)
def test_all_tasks_feature_consistency(self, sample_ensemble):
"""Test that all task types produce consistent features.""" """Test that all task types produce consistent features."""
tasks = ["binary", "count", "density"] feature_sets: dict[str, set[str]] = {}
feature_sets = {}
for task in tasks: for task in all_tasks:
training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu") training_set: TrainingSet = sample_ensemble.create_training_set(
feature_sets[task] = set(training_data.X.data.columns) task=task,
target="darts_v1",
device="cpu",
cache_mode="none",
)
feature_sets[task] = set(training_set.features.columns)
# All tasks should have the same features # All tasks should have the same features
binary_features = feature_sets["binary"] binary_features: set[str] = feature_sets["binary"]
for task, features in feature_sets.items(): for task, features in feature_sets.items():
assert features == binary_features, ( assert features == binary_features, (
f"Task '{task}' has different features than 'binary'.\n" f"Task '{task}' has different features than 'binary'.\n"
f"Difference: {features.symmetric_difference(binary_features)}" f"Difference: {features.symmetric_difference(binary_features)}"
) )
def test_training_dataset_shapes(self, sample_ensemble): def test_training_set_shapes(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that training dataset has correct shapes.""" """Test that training set has correct shapes."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
n_features = len(training_data.X.data.columns) n_features: int = len(training_set.features.columns)
n_samples_train = training_data.X.train.shape[0] n_samples_train: int = training_set.X.train.shape[0]
n_samples_test = training_data.X.test.shape[0] n_samples_test: int = training_set.X.test.shape[0]
# Check X shapes # Check X shapes
assert training_data.X.train.shape == (n_samples_train, n_features) assert training_set.X.train.shape == (n_samples_train, n_features)
assert training_data.X.test.shape == (n_samples_test, n_features) assert training_set.X.test.shape == (n_samples_test, n_features)
# Check y shapes # Check y shapes
assert training_data.y.train.shape == (n_samples_train,) assert training_set.y.train.shape == (n_samples_train,)
assert training_data.y.test.shape == (n_samples_test,) assert training_set.y.test.shape == (n_samples_test,)
# Check that train + test = total samples # Check that train + test = total samples
assert len(training_data.dataset) == n_samples_train + n_samples_test assert len(training_set) == n_samples_train + n_samples_test
def test_no_nan_in_training_features(self, sample_ensemble): def test_no_nan_in_training_features(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that training features don't contain NaN values.""" """Test that training features don't contain NaN values."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
# Convert to numpy for checking (handles both numpy and cupy arrays) # Convert to numpy for checking (handles both numpy and cupy arrays)
X_train = np.asarray(training_data.X.train) X_train: np.ndarray = np.asarray(training_set.X.train) # noqa: N806
X_test = np.asarray(training_data.X.test) X_test: np.ndarray = np.asarray(training_set.X.test) # noqa: N806
assert not np.isnan(X_train).any(), "Training features contain NaN values" assert not np.isnan(X_train).any(), "Training features contain NaN values"
assert not np.isnan(X_test).any(), "Test features contain NaN values" assert not np.isnan(X_test).any(), "Test features contain NaN values"
def test_batch_coverage(self, sample_ensemble): def test_batch_coverage(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that batches cover all data without duplication.""" """Test that batches cover all data without duplication."""
full_dataset = sample_ensemble.create(cache_mode="n") batches: list[pd.DataFrame] = list(sample_ensemble.create_inference_df(batch_size=100, cache_mode="none"))
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
if len(batches) == 0: if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)") pytest.skip("No batches created (dataset might be empty)")
@ -277,34 +278,91 @@ class TestDatasetEnsemble:
assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches" assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
batch_cell_ids.update(batch_ids) batch_cell_ids.update(batch_ids)
# Check that all cell_ids from full dataset are in batches # Check that all cell_ids from grid are in batches
full_cell_ids = set(full_dataset.index) all_cell_ids = set(sample_ensemble.cell_ids.to_numpy())
assert batch_cell_ids == full_cell_ids, ( assert batch_cell_ids == all_cell_ids, (
f"Batch coverage mismatch.\n" f"Batch coverage mismatch.\n"
f"Missing from batches: {full_cell_ids - batch_cell_ids}\n" f"Missing from batches: {all_cell_ids - batch_cell_ids}\n"
f"Extra in batches: {batch_cell_ids - full_cell_ids}" f"Extra in batches: {batch_cell_ids - all_cell_ids}"
) )
class TestDatasetEnsembleEdgeCases: class TestDatasetEnsembleEdgeCases:
"""Test edge cases and error handling.""" """Test edge cases and error handling."""
def test_invalid_task_raises_error(self, sample_ensemble): def test_invalid_task_raises_error(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that invalid task raises ValueError.""" """Test that invalid task raises ValueError."""
with pytest.raises(ValueError, match="Invalid task"): with pytest.raises(NotImplementedError, match=r"Task .* not supported"):
sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu") # type: ignore sample_ensemble.create_training_set(task="invalid", target="darts_v1", device="cpu") # type: ignore[arg-type]
def test_stats_method(self, sample_ensemble): def test_target_labels_property(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that get_stats returns expected structure.""" """Test that TrainingSet provides target labels for categorical tasks."""
stats = sample_ensemble.get_stats() training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
assert "target" in stats # Binary task should have labels
assert "num_target_samples" in stats labels: list[str] | None = training_set.target_labels
assert "members" in stats assert labels is not None
assert "total_features" in stats assert isinstance(labels, list)
assert "No RTS" in labels
assert "RTS" in labels
# Check that members dict contains info for each member def test_target_labels_none_for_regression(self, sample_ensemble: DatasetEnsemble) -> None:
for member in sample_ensemble.members: """Test that regression tasks return None for target_labels."""
assert member in stats["members"] training_set: TrainingSet = sample_ensemble.create_training_set(
assert "variables" in stats["members"][member] task="count", target="darts_v1", device="cpu", cache_mode="none"
assert "num_features" in stats["members"][member] )
# Regression task should return None
assert training_set.target_labels is None
def test_to_dataframe_method(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that to_dataframe returns expected structure."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
# Get full dataset
df_full: pd.DataFrame = training_set.to_dataframe(split=None)
assert isinstance(df_full, pd.DataFrame)
assert "label" in df_full.columns
assert len(df_full) == len(training_set)
# Get train split
df_train = training_set.to_dataframe(split="train")
assert len(df_train) < len(df_full)
assert all(training_set.split[df_train.index] == "train")
# Get test split
df_test = training_set.to_dataframe(split="test")
assert len(df_test) < len(df_full)
assert all(training_set.split[df_test.index] == "test")
# Train + test should equal full
assert len(df_train) + len(df_test) == len(df_full)
def test_different_targets(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that different target datasets can be used."""
for target in all_target_datasets:
try:
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary",
target=target,
device="cpu",
cache_mode="none",
)
assert len(training_set) > 0
except Exception as e:
pytest.skip(f"Target {target} not available: {e}")
def test_feature_names_property(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that feature_names property returns list of feature names."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
feature_names: list[str] = training_set.feature_names
assert isinstance(feature_names, list)
assert len(feature_names) == training_set.features.shape[1]
assert all(isinstance(name, str) for name in feature_names)