Fix dataset test

This commit is contained in:
Tobias Hölzer 2026-01-19 17:05:44 +01:00
parent 87a0d03af1
commit 073502c51d

View file

@@ -1,91 +1,129 @@
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
from collections.abc import Generator
import geopandas as gpd
import numpy as np
import pandas as pd
import pytest
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.utils.types import all_target_datasets, all_tasks
@pytest.fixture
def sample_ensemble():
def sample_ensemble() -> Generator[DatasetEnsemble]:
"""Create a sample DatasetEnsemble for testing with minimal data."""
return DatasetEnsemble(
yield DatasetEnsemble(
grid="hex",
level=3, # Use level 3 for much faster tests
target="darts_rts",
members=["AlphaEarth"], # Use only one member for faster tests
add_lonlat=True,
filter_target="darts_has_coverage", # Filter to reduce dataset size
)
@pytest.fixture
def sample_ensemble_mllabels():
"""Create a sample DatasetEnsemble with mllabels target."""
return DatasetEnsemble(
def sample_ensemble_v2() -> Generator[DatasetEnsemble]:
"""Create a sample DatasetEnsemble for testing with v2 target."""
yield DatasetEnsemble(
grid="hex",
level=3, # Use level 3 for much faster tests
target="darts_mllabels",
members=["AlphaEarth"], # Use only one member for faster tests
add_lonlat=True,
filter_target="dartsml_has_coverage", # Filter to reduce dataset size
)
class TestDatasetEnsemble:
"""Test suite for DatasetEnsemble class."""
def test_initialization(self, sample_ensemble):
def test_initialization(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that DatasetEnsemble initializes correctly."""
assert sample_ensemble.grid == "hex"
assert sample_ensemble.level == 3
assert sample_ensemble.target == "darts_rts"
assert "AlphaEarth" in sample_ensemble.members
assert sample_ensemble.add_lonlat is True
def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
"""Test that covcol property returns correct column name."""
assert sample_ensemble.covcol == "darts_has_coverage"
assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"
def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that get_targets() returns a GeoDataFrame."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
assert isinstance(targets, gpd.GeoDataFrame)
def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
"""Test that taskcol returns correct column name for different tasks."""
assert sample_ensemble.taskcol("binary") == "darts_has_rts"
assert sample_ensemble.taskcol("count") == "darts_rts_count"
assert sample_ensemble.taskcol("density") == "darts_rts_density"
def test_get_targets_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that get_targets() returns expected columns."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"
# Should have required columns
assert "geometry" in targets.columns
assert "y" in targets.columns # target label
assert "z" in targets.columns # raw value
assert targets.index.name == "cell_id"
def test_create_returns_geodataframe(self, sample_ensemble):
"""Test that create() returns a GeoDataFrame."""
dataset = sample_ensemble.create(cache_mode="n")
assert isinstance(dataset, gpd.GeoDataFrame)
def test_get_targets_binary_task(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that binary task creates proper labels."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
def test_create_has_expected_columns(self, sample_ensemble):
"""Test that create() returns dataset with expected columns."""
dataset = sample_ensemble.create(cache_mode="n")
# Binary task should have categorical y with "No RTS" and "RTS"
assert targets["y"].dtype.name == "category"
assert set(targets["y"].cat.categories) == {"No RTS", "RTS"}
# Should have geometry column
assert "geometry" in dataset.columns
# z should be numeric (count)
assert pd.api.types.is_numeric_dtype(targets["z"])
# Should have lat/lon if add_lonlat is True
def test_get_targets_count_task(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that count task creates proper labels."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="count")
# Count task should have numeric y and z that are identical
assert pd.api.types.is_numeric_dtype(targets["y"])
assert pd.api.types.is_numeric_dtype(targets["z"])
assert (targets["y"] == targets["z"]).all()
def test_get_targets_density_task(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that density task creates proper labels."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="density")
# Density task should have numeric y and z that are identical
assert pd.api.types.is_numeric_dtype(targets["y"])
assert pd.api.types.is_numeric_dtype(targets["z"])
assert (targets["y"] == targets["z"]).all()
def test_make_features_returns_dataframe(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that make_features() returns a DataFrame."""
# Get a small subset of cell_ids for faster testing
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
cell_ids: pd.Series = targets.index.to_series().head(50)
features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
assert isinstance(features, pd.DataFrame)
def test_make_features_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that make_features() returns expected feature columns."""
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
cell_ids: pd.Series = targets.index.to_series().head(50)
features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
# Should NOT have geometry column
assert "geometry" not in features.columns
# Should have location columns if add_lonlat is True
if sample_ensemble.add_lonlat:
assert "lon" in dataset.columns
assert "lat" in dataset.columns
assert "x" in features.columns
assert "y" in features.columns
# Should have target columns (darts_*)
assert any(col.startswith("darts_") for col in dataset.columns)
# Should have grid property columns
assert "cell_area" in features.columns
assert "land_area" in features.columns
# Should have member data columns
assert len(dataset.columns) > 3 # More than just geometry, lat, lon
# Should have member feature columns
assert any(col.startswith("embeddings_") for col in features.columns)
def test_create_batches_consistency(self, sample_ensemble):
"""Test that create_batches produces batches with consistent columns."""
batch_size = 100
batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n"))
def test_inference_df_batches_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that create_inference_df produces batches with consistent columns."""
batch_size: int = 100
batches: list[pd.DataFrame] = list(
sample_ensemble.create_inference_df(batch_size=batch_size, cache_mode="none")
)
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
@@ -98,83 +136,67 @@ class TestDatasetEnsemble:
f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
)
def test_create_vs_create_batches_columns(self, sample_ensemble):
"""Test that create() and create_batches() return datasets with the same columns."""
full_dataset = sample_ensemble.create(cache_mode="n")
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
# Columns should be identical
full_cols = set(full_dataset.columns)
batch_cols = set(batches[0].columns)
assert full_cols == batch_cols, (
f"Column mismatch between create() and create_batches().\n"
f"Only in create(): {full_cols - batch_cols}\n"
f"Only in create_batches(): {batch_cols - full_cols}"
def test_training_set_structure(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that create_training_set creates proper structure."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
def test_training_dataset_feature_columns(self, sample_ensemble):
"""Test that create_cat_training_dataset creates proper feature columns."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
# Check that TrainingSet has all expected attributes
assert hasattr(training_set, "targets")
assert hasattr(training_set, "features")
assert hasattr(training_set, "X")
assert hasattr(training_set, "y")
assert hasattr(training_set, "z")
assert hasattr(training_set, "split")
# Get the columns used for model inputs
model_input_cols = set(training_data.X.data.columns)
def test_training_set_feature_columns(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that create_training_set creates proper feature columns."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
# These columns should NOT be in model inputs
assert "geometry" not in model_input_cols
assert sample_ensemble.covcol not in model_input_cols
assert sample_ensemble.taskcol("binary") not in model_input_cols
# Get the feature column names
feature_cols: set[str] = set(training_set.features.columns)
# No darts_* columns should be in model inputs
for col in model_input_cols:
assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"
# These columns should NOT be in features
assert "geometry" not in feature_cols
# Note: "y" and "x" are spatial coordinates from the grid and are valid features
# The target labels are stored in training_set.targets["y"], not in features
def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
"""Test feature columns for mllabels target."""
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
def test_training_set_no_target_leakage(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that target information doesn't leak into features."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
model_input_cols = set(training_data.X.data.columns)
# Features should not contain target-related columns
feature_cols: pd.Index = training_set.features.columns
assert all(not col.startswith("darts_") for col in feature_cols)
# These columns should NOT be in model inputs
assert "geometry" not in model_input_cols
assert sample_ensemble_mllabels.covcol not in model_input_cols
assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols
# No dartsml_* columns should be in model inputs
for col in model_input_cols:
assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"
def test_inference_vs_training_feature_consistency(self, sample_ensemble):
"""Test that inference batches have the same features as training data after column dropping.
def test_inference_vs_training_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that inference batches have the same features as training data.
This test simulates the workflow in training.py and inference.py to ensure
feature consistency between training and inference.
"""
# Step 1: Create training dataset (as in training.py)
# Use only a small subset by creating just one batch
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_feature_cols = set(training_data.X.data.columns)
# Step 1: Create training set (as in training.py)
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
training_feature_cols: set[str] = set(training_set.features.columns)
# Step 2: Create inference batch (as in inference.py)
# Get just the first batch to speed up test
batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
batch_generator: Generator[pd.DataFrame] = sample_ensemble.create_inference_df(
batch_size=100, cache_mode="none"
)
batch: pd.DataFrame | None = next(batch_generator, None)
if batch is None:
pytest.skip("No batches created (dataset might be empty)")
# Simulate the column dropping in predict_proba
cols_to_drop = ["geometry"]
if sample_ensemble.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
inference_batch = batch.drop(columns=cols_to_drop)
inference_feature_cols = set(inference_batch.columns)
assert batch is not None # For type checker
inference_feature_cols: set[str] = set(batch.columns)
# The features should match!
assert training_feature_cols == inference_feature_cols, (
@@ -185,85 +207,64 @@ class TestDatasetEnsemble:
f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
)
def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
"""Test feature consistency for mllabels target."""
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
training_feature_cols = set(training_data.X.data.columns)
# Get just the first batch to speed up test
batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
if batch is None:
pytest.skip("No batches created (dataset might be empty)")
# Simulate the column dropping in predict_proba
cols_to_drop = ["geometry"]
if sample_ensemble_mllabels.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
inference_batch = batch.drop(columns=cols_to_drop)
inference_feature_cols = set(inference_batch.columns)
assert training_feature_cols == inference_feature_cols, (
f"Feature mismatch between training and inference!\n"
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
f"Only in inference: {inference_feature_cols - training_feature_cols}"
)
def test_all_tasks_feature_consistency(self, sample_ensemble):
def test_all_tasks_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that all task types produce consistent features."""
tasks = ["binary", "count", "density"]
feature_sets = {}
feature_sets: dict[str, set[str]] = {}
for task in tasks:
training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
feature_sets[task] = set(training_data.X.data.columns)
for task in all_tasks:
training_set: TrainingSet = sample_ensemble.create_training_set(
task=task,
target="darts_v1",
device="cpu",
cache_mode="none",
)
feature_sets[task] = set(training_set.features.columns)
# All tasks should have the same features
binary_features = feature_sets["binary"]
binary_features: set[str] = feature_sets["binary"]
for task, features in feature_sets.items():
assert features == binary_features, (
f"Task '{task}' has different features than 'binary'.\n"
f"Difference: {features.symmetric_difference(binary_features)}"
)
def test_training_dataset_shapes(self, sample_ensemble):
"""Test that training dataset has correct shapes."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
def test_training_set_shapes(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that training set has correct shapes."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
n_features = len(training_data.X.data.columns)
n_samples_train = training_data.X.train.shape[0]
n_samples_test = training_data.X.test.shape[0]
n_features: int = len(training_set.features.columns)
n_samples_train: int = training_set.X.train.shape[0]
n_samples_test: int = training_set.X.test.shape[0]
# Check X shapes
assert training_data.X.train.shape == (n_samples_train, n_features)
assert training_data.X.test.shape == (n_samples_test, n_features)
assert training_set.X.train.shape == (n_samples_train, n_features)
assert training_set.X.test.shape == (n_samples_test, n_features)
# Check y shapes
assert training_data.y.train.shape == (n_samples_train,)
assert training_data.y.test.shape == (n_samples_test,)
assert training_set.y.train.shape == (n_samples_train,)
assert training_set.y.test.shape == (n_samples_test,)
# Check that train + test = total samples
assert len(training_data.dataset) == n_samples_train + n_samples_test
assert len(training_set) == n_samples_train + n_samples_test
def test_no_nan_in_training_features(self, sample_ensemble):
def test_no_nan_in_training_features(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that training features don't contain NaN values."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
# Convert to numpy for checking (handles both numpy and cupy arrays)
X_train = np.asarray(training_data.X.train)
X_test = np.asarray(training_data.X.test)
X_train: np.ndarray = np.asarray(training_set.X.train) # noqa: N806
X_test: np.ndarray = np.asarray(training_set.X.test) # noqa: N806
assert not np.isnan(X_train).any(), "Training features contain NaN values"
assert not np.isnan(X_test).any(), "Test features contain NaN values"
def test_batch_coverage(self, sample_ensemble):
def test_batch_coverage(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that batches cover all data without duplication."""
full_dataset = sample_ensemble.create(cache_mode="n")
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
batches: list[pd.DataFrame] = list(sample_ensemble.create_inference_df(batch_size=100, cache_mode="none"))
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
@@ -277,34 +278,91 @@ class TestDatasetEnsemble:
assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
batch_cell_ids.update(batch_ids)
# Check that all cell_ids from full dataset are in batches
full_cell_ids = set(full_dataset.index)
assert batch_cell_ids == full_cell_ids, (
# Check that all cell_ids from grid are in batches
all_cell_ids = set(sample_ensemble.cell_ids.to_numpy())
assert batch_cell_ids == all_cell_ids, (
f"Batch coverage mismatch.\n"
f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
f"Extra in batches: {batch_cell_ids - full_cell_ids}"
f"Missing from batches: {all_cell_ids - batch_cell_ids}\n"
f"Extra in batches: {batch_cell_ids - all_cell_ids}"
)
class TestDatasetEnsembleEdgeCases:
"""Test edge cases and error handling."""
def test_invalid_task_raises_error(self, sample_ensemble):
def test_invalid_task_raises_error(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that invalid task raises ValueError."""
with pytest.raises(ValueError, match="Invalid task"):
sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu") # type: ignore
with pytest.raises(NotImplementedError, match=r"Task .* not supported"):
sample_ensemble.create_training_set(task="invalid", target="darts_v1", device="cpu") # type: ignore[arg-type]
def test_stats_method(self, sample_ensemble):
"""Test that get_stats returns expected structure."""
stats = sample_ensemble.get_stats()
def test_target_labels_property(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that TrainingSet provides target labels for categorical tasks."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
assert "target" in stats
assert "num_target_samples" in stats
assert "members" in stats
assert "total_features" in stats
# Binary task should have labels
labels: list[str] | None = training_set.target_labels
assert labels is not None
assert isinstance(labels, list)
assert "No RTS" in labels
assert "RTS" in labels
# Check that members dict contains info for each member
for member in sample_ensemble.members:
assert member in stats["members"]
assert "variables" in stats["members"][member]
assert "num_features" in stats["members"][member]
def test_target_labels_none_for_regression(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that regression tasks return None for target_labels."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="count", target="darts_v1", device="cpu", cache_mode="none"
)
# Regression task should return None
assert training_set.target_labels is None
def test_to_dataframe_method(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that to_dataframe returns expected structure."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
# Get full dataset
df_full: pd.DataFrame = training_set.to_dataframe(split=None)
assert isinstance(df_full, pd.DataFrame)
assert "label" in df_full.columns
assert len(df_full) == len(training_set)
# Get train split
df_train = training_set.to_dataframe(split="train")
assert len(df_train) < len(df_full)
assert all(training_set.split[df_train.index] == "train")
# Get test split
df_test = training_set.to_dataframe(split="test")
assert len(df_test) < len(df_full)
assert all(training_set.split[df_test.index] == "test")
# Train + test should equal full
assert len(df_train) + len(df_test) == len(df_full)
def test_different_targets(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that different target datasets can be used."""
for target in all_target_datasets:
try:
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary",
target=target,
device="cpu",
cache_mode="none",
)
assert len(training_set) > 0
except Exception as e:
pytest.skip(f"Target {target} not available: {e}")
def test_feature_names_property(self, sample_ensemble: DatasetEnsemble) -> None:
"""Test that feature_names property returns list of feature names."""
training_set: TrainingSet = sample_ensemble.create_training_set(
task="binary", target="darts_v1", device="cpu", cache_mode="none"
)
feature_names: list[str] = training_set.feature_names
assert isinstance(feature_names, list)
assert len(feature_names) == training_set.features.shape[1]
assert all(isinstance(name, str) for name in feature_names)