entropice/test_feature_extraction.py

146 lines
4.7 KiB
Python

"""Test script to verify feature extraction works correctly."""
import numpy as np
import xarray as xr
# Mock feature names, one group per naming scheme used by the extraction
# helpers exercised further down in this script.
features = [
    # Embedding features: embedding_{agg}_{band}_{year}
    "embedding_mean_B02_2020",
    "embedding_std_B03_2021",
    "embedding_max_B04_2022",
    # ERA5 features without aggregations: era5_{variable}_{time}
    "era5_temperature_2020_summer",
    "era5_precipitation_2021_winter",
    # ERA5 features with aggregations: era5_{variable}_{time}_{agg}
    "era5_temperature_2020_summer_mean",
    "era5_precipitation_2021_winter_std",
    # ArcticDEM features: arcticdem_{variable}_{agg}
    "arcticdem_elevation_mean",
    "arcticdem_slope_std",
    "arcticdem_aspect_max",
    # Common features
    "cell_area",
    "water_area",
    "land_area",
    "land_ratio",
    "lon",
    "lat",
]

# Mock importance values, one per feature, uniformly drawn from [0, 1).
# A seeded generator is used so every run of the script sees the same data;
# the unseeded global RNG produced different values on each run.
importance_values = np.random.default_rng(seed=42).random(len(features))
# ESPA mock state: a single "feature_weights" variable laid out along the
# "feature" dimension and labelled with the mock feature names.
model_state_espa = xr.Dataset(
    data_vars=dict(
        feature_weights=xr.DataArray(
            data=importance_values,
            coords={"feature": features},
            dims=("feature",),
        ),
    ),
)
# XGBoost mock state: two importance variables over the same "feature"
# dimension; the weight variant is the gain values scaled by 0.8.
_feature_axis = {"dims": ("feature",), "coords": {"feature": features}}
model_state_xgb = xr.Dataset(
    data_vars={
        "feature_importance_gain": xr.DataArray(importance_values, **_feature_axis),
        "feature_importance_weight": xr.DataArray(importance_values * 0.8, **_feature_axis),
    },
)
# Random Forest mock state: one "feature_importance" variable along the
# "feature" dimension, labelled with the mock feature names.
_rf_importance = xr.DataArray(
    data=importance_values,
    dims=("feature",),
    coords={"feature": features},
)
model_state_rf = xr.Dataset({"feature_importance": _rf_importance})
# Test extraction functions
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
)
# ---------------------------------------------------------------------------
# Run the extraction helpers against each mock model state and report what
# came back. Extractions that return None are recorded so the closing banner
# states the real outcome instead of always claiming success.
# ---------------------------------------------------------------------------
_SEPARATOR = "=" * 80

# How each optional detail line is derived from an extracted array.
_DETAIL_GETTERS = {
    "Dimensions": lambda array: array.dims,
    "Shape": lambda array: array.shape,
    "Coordinates": lambda array: list(array.coords),
    "Size": lambda array: array.size,
}
_DIMS_SHAPE = ("Dimensions", "Shape")
_DIMS_SHAPE_COORDS = ("Dimensions", "Shape", "Coordinates")
_DIMS_SHAPE_SIZE = ("Dimensions", "Shape", "Size")

# "<model>: <label>" for every extraction that returned None.
_failed_extractions = []


def _print_banner(title, *, leading_blank=True):
    """Print ``title`` framed by separator lines.

    Args:
        title: Text shown between the two separator lines.
        leading_blank: When true, a blank line precedes the banner.
    """
    print("\n" + _SEPARATOR if leading_blank else _SEPARATOR)
    print(title)
    print(_SEPARATOR)


def _report(model, label, array, details=_DIMS_SHAPE):
    """Print whether ``array`` was extracted and, if so, the requested details.

    Args:
        model: Name of the model state being tested; only used to identify
            the extraction in the failure summary.
        label: Human-readable name of the feature group, used in the output.
        array: Result of an ``extract_*`` call, or None.
        details: Names of the detail lines to print, each a key of
            ``_DETAIL_GETTERS``.
    """
    extracted = array is not None
    print(f"\n{label} extracted: {extracted}")
    if not extracted:
        # The mock feature list contains entries for every feature group, so
        # a None result is treated as a failed extraction.
        _failed_extractions.append(f"{model}: {label}")
        return
    for detail in details:
        print(f" {detail}: {_DETAIL_GETTERS[detail](array)}")


_print_banner("Testing ESPA model state", leading_blank=False)

embedding_array = extract_embedding_features(model_state_espa)
_report("ESPA", "Embedding features", embedding_array, _DIMS_SHAPE_COORDS)

era5_array = extract_era5_features(model_state_espa)
_report("ESPA", "ERA5 features", era5_array, _DIMS_SHAPE_COORDS)

arcticdem_array = extract_arcticdem_features(model_state_espa)
_report("ESPA", "ArcticDEM features", arcticdem_array, _DIMS_SHAPE_COORDS)

common_array = extract_common_features(model_state_espa)
_report("ESPA", "Common features", common_array, _DIMS_SHAPE_SIZE)

_print_banner("Testing XGBoost model state")

embedding_array_xgb = extract_embedding_features(model_state_xgb, importance_type="feature_importance_gain")
_report("XGBoost", "Embedding features (gain)", embedding_array_xgb)

era5_array_xgb = extract_era5_features(model_state_xgb, importance_type="feature_importance_weight")
_report("XGBoost", "ERA5 features (weight)", era5_array_xgb)

_print_banner("Testing Random Forest model state")

embedding_array_rf = extract_embedding_features(model_state_rf, importance_type="feature_importance")
_report("Random Forest", "Embedding features", embedding_array_rf)

arcticdem_array_rf = extract_arcticdem_features(model_state_rf, importance_type="feature_importance")
_report("Random Forest", "ArcticDEM features", arcticdem_array_rf)

# Only claim success when every extraction actually produced an array.
if _failed_extractions:
    _print_banner(
        f"{len(_failed_extractions)} extraction(s) returned None: "
        + ", ".join(_failed_extractions)
    )
else:
    _print_banner("All tests completed successfully!")