Start redoing the dashboard
This commit is contained in:
parent
d22b857722
commit
f5ea72e05e
22 changed files with 2610 additions and 1448 deletions
|
|
@ -75,62 +75,14 @@ Since the resolution of the ERA5 dataset is spatially smaller than the resolutio
|
||||||
|
|
||||||
For geometries crossing the antimeridian, geometries are corrected.
|
For geometries crossing the antimeridian, geometries are corrected.
|
||||||
|
|
||||||
| grid | method |
|
| grid | method | ~#pixel |
|
||||||
| ----- | ----------- |
|
| ----- | ----------- | ------------ |
|
||||||
| Hex3 | Common |
|
| Hex3 | Common | 235 [30,850] |
|
||||||
| Hex4 | Mean |
|
| Hex4 | Common | 44 [8,166] |
|
||||||
| Hex5 | Interpolate |
|
| Hex5 | Mean-only | 11 [3,41] |
|
||||||
| Hex6 | Interpolate |
|
| Hex6 | Interpolate | 4 [2,14] |
|
||||||
| Hpx6 | Common |
|
| Hpx6 | Common | 204 [25,769] |
|
||||||
| Hpx7 | Mean |
|
| Hpx7 | Common | 62 [9,231] |
|
||||||
| Hpx8 | Mean |
|
| Hpx8 | Mean-only | 21 [4,75] |
|
||||||
| Hpx9 | Interpolate |
|
| Hpx9 | Mean-only | 9 [2,29] |
|
||||||
| Hpx10 | Interpolate |
|
| Hpx10 | Interpolate | 2 [2,15] |
|
||||||
|
|
||||||
- hex level 3
|
|
||||||
min: 30.0
|
|
||||||
max: 850.0
|
|
||||||
mean: 251.25216674804688
|
|
||||||
median: 235.5
|
|
||||||
- hex level 4
|
|
||||||
min: 8.0
|
|
||||||
max: 166.0
|
|
||||||
mean: 47.2462158203125
|
|
||||||
median: 44.0
|
|
||||||
- hex level 5
|
|
||||||
min: 3.0
|
|
||||||
max: 41.0
|
|
||||||
mean: 11.164162635803223
|
|
||||||
median: 10.0
|
|
||||||
- hex level 6
|
|
||||||
min: 2.0
|
|
||||||
max: 14.0
|
|
||||||
mean: 4.509947776794434
|
|
||||||
median: 4.0
|
|
||||||
- healpix level 6
|
|
||||||
min: 25.0
|
|
||||||
max: 769.0
|
|
||||||
mean: 214.97296142578125
|
|
||||||
median: 204.0
|
|
||||||
healpix level 7
|
|
||||||
min: 9.0
|
|
||||||
max: 231.0
|
|
||||||
mean: 65.91140747070312
|
|
||||||
median: 62.0
|
|
||||||
healpix level 8
|
|
||||||
min: 4.0
|
|
||||||
max: 75.0
|
|
||||||
mean: 22.516725540161133
|
|
||||||
median: 21.0
|
|
||||||
healpix level 9
|
|
||||||
min: 2.0
|
|
||||||
max: 29.0
|
|
||||||
mean: 8.952080726623535
|
|
||||||
median: 9.0
|
|
||||||
healpix level 10
|
|
||||||
min: 2.0
|
|
||||||
max: 15.0
|
|
||||||
mean: 4.361577987670898
|
|
||||||
median: 4.0
|
|
||||||
|
|
||||||
???
|
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,10 @@ dependencies = [
|
||||||
"xarray-histogram>=0.2.2,<0.3",
|
"xarray-histogram>=0.2.2,<0.3",
|
||||||
"antimeridian>=0.4.5,<0.5",
|
"antimeridian>=0.4.5,<0.5",
|
||||||
"duckdb>=1.4.2,<2",
|
"duckdb>=1.4.2,<2",
|
||||||
|
"pydeck>=0.9.1,<0.10",
|
||||||
|
"pypalettes>=0.2.1,<0.3",
|
||||||
|
"ty>=0.0.2,<0.0.3",
|
||||||
|
"ruff>=0.14.9,<0.15", "pandas-stubs>=2.3.3.251201,<3",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
@ -70,8 +74,7 @@ darts = "entropice.darts:cli"
|
||||||
alpha-earth = "entropice.alphaearth:main"
|
alpha-earth = "entropice.alphaearth:main"
|
||||||
era5 = "entropice.era5:cli"
|
era5 = "entropice.era5:cli"
|
||||||
arcticdem = "entropice.arcticdem:cli"
|
arcticdem = "entropice.arcticdem:cli"
|
||||||
train = "entropice.training:main"
|
train = "entropice.training:cli"
|
||||||
dataset = "entropice.dataset:main"
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|
@ -126,7 +129,7 @@ entropice = { path = ".", editable = true }
|
||||||
dashboard = { cmd = [
|
dashboard = { cmd = [
|
||||||
"streamlit",
|
"streamlit",
|
||||||
"run",
|
"run",
|
||||||
"src/entropice/training_analysis_dashboard.py",
|
"src/entropice/dashboard/app.py",
|
||||||
"--server.port",
|
"--server.port",
|
||||||
"8501",
|
"8501",
|
||||||
"--server.address",
|
"--server.address",
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ import geopandas as gpd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import odc.geo.geobox
|
import odc.geo.geobox
|
||||||
|
import odc.geo.types
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import psutil
|
import psutil
|
||||||
import shapely
|
import shapely
|
||||||
|
|
@ -186,7 +187,7 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
||||||
x, y = enclosing.shape
|
x, y = enclosing.shape
|
||||||
if x <= 1 or y <= 1:
|
if x <= 1 or y <= 1:
|
||||||
return False
|
return False
|
||||||
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing)
|
roi: odc.geo.types.NormalizedROI = geobox.overlap_roi(enclosing)
|
||||||
roix, roiy = roi
|
roix, roiy = roi
|
||||||
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
||||||
|
|
||||||
|
|
@ -216,7 +217,7 @@ def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggreg
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Extracting split cell data", log=False)
|
@stopwatch("Extracting split cell data", log=False)
|
||||||
def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations):
|
def _extract_split_cell_data(cropped_list: list[xr.Dataset] | list[xr.DataArray], aggregations: _Aggregations):
|
||||||
spatdims = (
|
spatdims = (
|
||||||
["latitude", "longitude"]
|
["latitude", "longitude"]
|
||||||
if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims
|
if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims
|
||||||
|
|
@ -370,6 +371,7 @@ def _align_partition(
|
||||||
# => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes
|
# => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes
|
||||||
# faster than a processpoolexecutor
|
# faster than a processpoolexecutor
|
||||||
|
|
||||||
|
assert memprof is not None, "Memory profiler is not initialized in worker"
|
||||||
memprof.log_memory("Before reading partial raster", log=False)
|
memprof.log_memory("Before reading partial raster", log=False)
|
||||||
|
|
||||||
need_to_close_raster = False
|
need_to_close_raster = False
|
||||||
|
|
@ -513,6 +515,7 @@ def _align_data(
|
||||||
n_partitions = len(grid_gdf)
|
n_partitions = len(grid_gdf)
|
||||||
grid_partitions = grid_gdf
|
grid_partitions = grid_gdf
|
||||||
else:
|
else:
|
||||||
|
assert n_partitions is not None, "n_partitions must be provided when grid_gdf is not a list"
|
||||||
grid_partitions = partition_grid(grid_gdf, n_partitions)
|
grid_partitions = partition_grid(grid_gdf, n_partitions)
|
||||||
|
|
||||||
if n_partitions < concurrent_partitions:
|
if n_partitions < concurrent_partitions:
|
||||||
|
|
@ -539,7 +542,8 @@ def _align_data(
|
||||||
# For spawn or forkserver, we need to copy the raster into each worker
|
# For spawn or forkserver, we need to copy the raster into each worker
|
||||||
workerargs = (None if not is_raster_in_memory or not is_mpfork else raster,)
|
workerargs = (None if not is_raster_in_memory or not is_mpfork else raster,)
|
||||||
# For mp start method fork, we can share the raster dataset between workers
|
# For mp start method fork, we can share the raster dataset between workers
|
||||||
if mp.get_start_method(allow_none=True) == "fork" and is_raster_in_memory:
|
if is_mpfork and is_raster_in_memory:
|
||||||
|
assert isinstance(raster, xr.Dataset) # satisfy type checker, but this is already checked above
|
||||||
_init_worker(raster)
|
_init_worker(raster)
|
||||||
|
|
||||||
with ProcessPoolExecutor(
|
with ProcessPoolExecutor(
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
||||||
scale_factor = scale_factors[grid][level]
|
scale_factor = scale_factors[grid][level]
|
||||||
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
|
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
|
||||||
|
|
||||||
for year in track(range(2018, 2025), total=7, description="Processing years..."):
|
for year in track(range(2021, 2025), total=4, description="Processing years..."):
|
||||||
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
|
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
|
||||||
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
|
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
|
||||||
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
|
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import smart_geocubes
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
import xdggs
|
import xdggs
|
||||||
import xrspatial
|
import xrspatial
|
||||||
|
import xrspatial.convolution
|
||||||
import zarr
|
import zarr
|
||||||
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
|
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
|
||||||
from rich import pretty, print, traceback
|
from rich import pretty, print, traceback
|
||||||
|
|
@ -116,7 +117,8 @@ def ruggedness_cupy(chunk, slope, aspect, kernels: _KernelFactory):
|
||||||
return vrm
|
return vrm
|
||||||
|
|
||||||
|
|
||||||
def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> tuple[cp.array, cp.array]:
|
def _get_xy_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=None) -> tuple[cp.ndarray, cp.ndarray]:
|
||||||
|
assert isinstance(block_info, list) and len(block_info) >= 1
|
||||||
chunk_loc = block_info[0]["chunk-location"]
|
chunk_loc = block_info[0]["chunk-location"]
|
||||||
d = 15
|
d = 15
|
||||||
cs = 3600
|
cs = 3600
|
||||||
|
|
@ -149,7 +151,7 @@ def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
|
||||||
return xx, yy
|
return xx, yy
|
||||||
|
|
||||||
|
|
||||||
def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array:
|
def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=None) -> np.ndarray:
|
||||||
res = 32 # 32m resolution
|
res = 32 # 32m resolution
|
||||||
small_kernels = _KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m)
|
small_kernels = _KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m)
|
||||||
medium_kernels = _KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m)
|
medium_kernels = _KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m)
|
||||||
|
|
|
||||||
|
|
@ -63,10 +63,15 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply corrections to NaNs
|
# Apply corrections to NaNs
|
||||||
covered = ~grid_gdf[f"darts_{year}_coverage"].isnull()
|
covered = ~grid_gdf[f"darts_{year}_coverage"].isna()
|
||||||
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
|
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
|
||||||
0.0
|
0.0
|
||||||
)
|
)
|
||||||
|
grid_gdf.loc[covered, f"darts_{year}_rts_density"] = grid_gdf.loc[
|
||||||
|
covered, f"darts_{year}_rts_density"
|
||||||
|
].fillna(0.0)
|
||||||
|
grid_gdf[f"darts_{year}_has_coverage"] = covered
|
||||||
|
grid_gdf[f"darts_{year}_has_rts"] = grid_gdf[f"darts_{year}_rts_count"] > 0
|
||||||
|
|
||||||
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
||||||
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
||||||
|
|
@ -128,9 +133,10 @@ def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
|
||||||
# Apply corrections to NaNs
|
# Apply corrections to NaNs
|
||||||
covered = ~grid_gdf["dartsml_coverage"].isna()
|
covered = ~grid_gdf["dartsml_coverage"].isna()
|
||||||
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
|
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
|
||||||
|
grid_gdf.loc[covered, "dartsml_rts_density"] = grid_gdf.loc[covered, "dartsml_rts_density"].fillna(0.0)
|
||||||
|
|
||||||
grid_gdf["dartsml_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
|
grid_gdf["dartsml_has_coverage"] = covered
|
||||||
grid_gdf["dartsml_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()
|
grid_gdf["dartsml_has_rts"] = grid_gdf["dartsml_rts_count"] > 0
|
||||||
|
|
||||||
output_path = get_darts_rts_file(grid, level, labels=True)
|
output_path = get_darts_rts_file(grid, level, labels=True)
|
||||||
grid_gdf.to_parquet(output_path)
|
grid_gdf.to_parquet(output_path)
|
||||||
|
|
|
||||||
0
src/entropice/dashboard/__init__.py
Normal file
0
src/entropice/dashboard/__init__.py
Normal file
44
src/entropice/dashboard/app.py
Normal file
44
src/entropice/dashboard/app.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Streamlit app for Entropice dashboard.
|
||||||
|
|
||||||
|
Pages:
|
||||||
|
|
||||||
|
- Overview: List of available result directories with some summary statistics.
|
||||||
|
- Training Data: Visualization of training data distributions.
|
||||||
|
- Training Results Analysis: Analysis of training results and model performance.
|
||||||
|
- Model State: Visualization of model state and features.
|
||||||
|
- Inference: Visualization of inference results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.inference_page import render_inference_page
|
||||||
|
from entropice.dashboard.model_state_page import render_model_state_page
|
||||||
|
from entropice.dashboard.overview_page import render_overview_page
|
||||||
|
from entropice.dashboard.training_analysis_page import render_training_analysis_page
|
||||||
|
from entropice.dashboard.training_data_page import render_training_data_page
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run the dashboard."""
|
||||||
|
st.set_page_config(page_title="Entropice Dashboard", layout="wide")
|
||||||
|
|
||||||
|
# Setup Navigation
|
||||||
|
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||||
|
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||||
|
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||||
|
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||||
|
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||||
|
|
||||||
|
pg = st.navigation(
|
||||||
|
{
|
||||||
|
"Overview": [overview_page],
|
||||||
|
"Training": [training_data_page, training_analysis_page],
|
||||||
|
"Model State": [model_state_page],
|
||||||
|
"Inference": [inference_page],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
pg.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
8
src/entropice/dashboard/inference_page.py
Normal file
8
src/entropice/dashboard/inference_page.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
|
||||||
|
def render_inference_page():
|
||||||
|
"""Render the Inference page of the dashboard."""
|
||||||
|
st.title("Inference Results")
|
||||||
|
st.write("This page will display inference results and visualizations.")
|
||||||
|
# Add more components and visualizations as needed for inference results.
|
||||||
8
src/entropice/dashboard/model_state_page.py
Normal file
8
src/entropice/dashboard/model_state_page.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
|
||||||
|
def render_model_state_page():
|
||||||
|
"""Render the Model State page of the dashboard."""
|
||||||
|
st.title("Model State")
|
||||||
|
st.write("This page will display model state and feature visualizations.")
|
||||||
|
# Add more components and visualizations as needed for model state.
|
||||||
155
src/entropice/dashboard/overview_page.py
Normal file
155
src/entropice/dashboard/overview_page.py
Normal file
|
|
@ -0,0 +1,155 @@
|
||||||
|
"""Overview page: List of available result directories with some summary statistics."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.data import load_all_training_results
|
||||||
|
|
||||||
|
|
||||||
|
def render_overview_page():
|
||||||
|
"""Render the Overview page of the dashboard."""
|
||||||
|
st.title("🏡 Training Results Overview")
|
||||||
|
|
||||||
|
training_results = load_all_training_results()
|
||||||
|
|
||||||
|
if not training_results:
|
||||||
|
st.warning("No training results found. Please run some training experiments first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||||
|
|
||||||
|
# Summary statistics at the top
|
||||||
|
st.subheader("Summary Statistics")
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
tasks = {tr.settings.get("task", "Unknown") for tr in training_results}
|
||||||
|
st.metric("Tasks", len(tasks))
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
grids = {tr.settings.get("grid", "Unknown") for tr in training_results}
|
||||||
|
st.metric("Grid Types", len(grids))
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
models = {tr.settings.get("model", "Unknown") for tr in training_results}
|
||||||
|
st.metric("Model Types", len(models))
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
latest = training_results[0] # Already sorted by creation time
|
||||||
|
latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
|
||||||
|
st.metric("Latest Run", latest_date)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Detailed results table
|
||||||
|
st.subheader("Training Results")
|
||||||
|
|
||||||
|
# Build a summary dataframe
|
||||||
|
summary_data = []
|
||||||
|
for tr in training_results:
|
||||||
|
# Extract best scores from the results dataframe
|
||||||
|
score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
|
||||||
|
|
||||||
|
best_scores = {}
|
||||||
|
for col in score_cols:
|
||||||
|
metric_name = col.replace("mean_test_", "")
|
||||||
|
best_score = tr.results[col].max()
|
||||||
|
best_scores[metric_name] = best_score
|
||||||
|
|
||||||
|
# Get primary metric (usually the first one or accuracy)
|
||||||
|
primary_metric = (
|
||||||
|
"accuracy"
|
||||||
|
if "mean_test_accuracy" in tr.results.columns
|
||||||
|
else score_cols[0].replace("mean_test_", "")
|
||||||
|
if score_cols
|
||||||
|
else "N/A"
|
||||||
|
)
|
||||||
|
primary_score = best_scores.get(primary_metric, 0.0)
|
||||||
|
|
||||||
|
summary_data.append(
|
||||||
|
{
|
||||||
|
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
|
||||||
|
"Task": tr.settings.get("task", "Unknown"),
|
||||||
|
"Grid": tr.settings.get("grid", "Unknown"),
|
||||||
|
"Level": tr.settings.get("level", "Unknown"),
|
||||||
|
"Model": tr.settings.get("model", "Unknown"),
|
||||||
|
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
|
||||||
|
"Trials": len(tr.results),
|
||||||
|
"Path": str(tr.path.name),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
summary_df = pd.DataFrame(summary_data)
|
||||||
|
|
||||||
|
# Display with color coding for best scores
|
||||||
|
st.dataframe(
|
||||||
|
summary_df,
|
||||||
|
width="stretch",
|
||||||
|
hide_index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Expandable details for each result
|
||||||
|
st.subheader("Detailed Results")
|
||||||
|
|
||||||
|
for tr in training_results:
|
||||||
|
with st.expander(tr.name):
|
||||||
|
col1, col2 = st.columns([1, 2])
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.write("**Configuration:**")
|
||||||
|
st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}")
|
||||||
|
st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}")
|
||||||
|
st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}")
|
||||||
|
st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}")
|
||||||
|
st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}")
|
||||||
|
st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}")
|
||||||
|
|
||||||
|
st.write("\n**Files:**")
|
||||||
|
st.write("- 📊 search_results.parquet")
|
||||||
|
st.write("- 🧮 best_estimator_state.nc")
|
||||||
|
st.write("- 🎯 predicted_probabilities.parquet")
|
||||||
|
st.write("- ⚙️ search_settings.toml")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.write("**Best Scores:**")
|
||||||
|
|
||||||
|
# Extract all test scores
|
||||||
|
score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
|
||||||
|
|
||||||
|
if score_cols:
|
||||||
|
metric_data = []
|
||||||
|
for col in score_cols:
|
||||||
|
metric_name = col.replace("mean_test_", "").title()
|
||||||
|
best_score = tr.results[col].max()
|
||||||
|
mean_score = tr.results[col].mean()
|
||||||
|
std_score = tr.results[col].std()
|
||||||
|
|
||||||
|
metric_data.append(
|
||||||
|
{
|
||||||
|
"Metric": metric_name,
|
||||||
|
"Best": f"{best_score:.4f}",
|
||||||
|
"Mean": f"{mean_score:.4f}",
|
||||||
|
"Std": f"{std_score:.4f}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
metric_df = pd.DataFrame(metric_data)
|
||||||
|
st.dataframe(metric_df, width="stretch", hide_index=True)
|
||||||
|
else:
|
||||||
|
st.write("No test scores found in results.")
|
||||||
|
|
||||||
|
# Show parameter space explored
|
||||||
|
if "initial_K" in tr.results.columns: # Common parameter
|
||||||
|
st.write("\n**Parameter Ranges Explored:**")
|
||||||
|
for param in ["initial_K", "eps_cl", "eps_e"]:
|
||||||
|
if param in tr.results.columns:
|
||||||
|
min_val = tr.results[param].min()
|
||||||
|
max_val = tr.results[param].max()
|
||||||
|
unique_vals = tr.results[param].nunique()
|
||||||
|
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
||||||
|
|
||||||
|
st.write(f"\n**Path:** `{tr.path}`")
|
||||||
92
src/entropice/dashboard/plots/colors.py
Normal file
92
src/entropice/dashboard/plots/colors.py
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
"""Color related utilities for dashboard plots.
|
||||||
|
|
||||||
|
Color palettes from https://python-graph-gallery.com/color-palette-finder/
|
||||||
|
|
||||||
|
Material palettes:
|
||||||
|
- amber_material
|
||||||
|
- blue_grey_material
|
||||||
|
- blue_material
|
||||||
|
- brown_material
|
||||||
|
- cyan_material
|
||||||
|
- deep_orange_material
|
||||||
|
- deep_purple_material
|
||||||
|
- green_material
|
||||||
|
- grey_material
|
||||||
|
- indigo_material
|
||||||
|
- light_blue_material
|
||||||
|
- light_green_material
|
||||||
|
- lime_material
|
||||||
|
- orange_material
|
||||||
|
- pink_material
|
||||||
|
- purple_material
|
||||||
|
- red_material
|
||||||
|
- teal_material
|
||||||
|
- yellow_material
|
||||||
|
"""
|
||||||
|
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
from pypalettes import load_cmap
|
||||||
|
|
||||||
|
|
||||||
|
def get_cmap(variable: str) -> mcolors.Colormap:
|
||||||
|
"""Get a color palette by a "data" variable.
|
||||||
|
|
||||||
|
Each variable (meaning of the data) should be associated with another color palette when plotting.
|
||||||
|
This function should help to standardize the color palettes used for each variable type.
|
||||||
|
|
||||||
|
The variable can be any string, descriptive names are recommended.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variable: The variable to load a palette for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of hex color strings.
|
||||||
|
|
||||||
|
"""
|
||||||
|
material_palettes = [
|
||||||
|
"amber_material",
|
||||||
|
"blue_grey_material",
|
||||||
|
"blue_material",
|
||||||
|
"brown_material",
|
||||||
|
"cyan_material",
|
||||||
|
"deep_orange_material",
|
||||||
|
"deep_purple_material",
|
||||||
|
"green_material",
|
||||||
|
"grey_material",
|
||||||
|
"indigo_material",
|
||||||
|
"light_blue_material",
|
||||||
|
"light_green_material",
|
||||||
|
"lime_material",
|
||||||
|
"orange_material",
|
||||||
|
"pink_material",
|
||||||
|
"purple_material",
|
||||||
|
"red_material",
|
||||||
|
"teal_material",
|
||||||
|
"yellow_material",
|
||||||
|
]
|
||||||
|
# Fuzzy map from variable type to palette name
|
||||||
|
material_idx = sum(ord(c) for c in variable) % len(material_palettes)
|
||||||
|
palette_name = material_palettes[material_idx]
|
||||||
|
cmap = load_cmap(name=palette_name)
|
||||||
|
return cmap
|
||||||
|
|
||||||
|
|
||||||
|
def get_palette(variable: str, n_colors: int) -> list[str]:
|
||||||
|
"""Get a color palette by a "data" variable.
|
||||||
|
|
||||||
|
Each variable (meaning of the data) should be associated with another color palette when plotting.
|
||||||
|
This function should help to standardize the color palettes used for each variable type.
|
||||||
|
|
||||||
|
The variable can be any string, descriptive names are recommended.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variable: The variable to load a palette for.
|
||||||
|
n_colors: The number of colors to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of hex color strings.
|
||||||
|
|
||||||
|
"""
|
||||||
|
cmap = get_cmap(variable).resampled(n_colors)
|
||||||
|
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
|
||||||
|
return colors
|
||||||
356
src/entropice/dashboard/plots/training_data.py
Normal file
356
src/entropice/dashboard/plots/training_data.py
Normal file
|
|
@ -0,0 +1,356 @@
|
||||||
|
"""Plotting functions for training data visualizations."""
|
||||||
|
|
||||||
|
import geopandas as gpd
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import pydeck as pdk
|
||||||
|
import streamlit as st
|
||||||
|
from shapely.geometry import shape
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.colors import get_palette
|
||||||
|
from entropice.dataset import CategoricalTrainingDataset
|
||||||
|
|
||||||
|
|
||||||
|
def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTrainingDataset]):
|
||||||
|
"""Render histograms for all three tasks side by side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.subheader("📊 Target Distribution by Task")
|
||||||
|
|
||||||
|
# Create a 3-column layout for the three tasks
|
||||||
|
cols = st.columns(3)
|
||||||
|
|
||||||
|
tasks = ["binary", "count", "density"]
|
||||||
|
task_titles = {
|
||||||
|
"binary": "Binary Classification",
|
||||||
|
"count": "Count Classification",
|
||||||
|
"density": "Density Classification",
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, task in enumerate(tasks):
|
||||||
|
dataset = train_data_dict[task]
|
||||||
|
categories = dataset.y.binned.cat.categories.tolist()
|
||||||
|
colors = get_palette(task, len(categories))
|
||||||
|
|
||||||
|
with cols[idx]:
|
||||||
|
st.markdown(f"**{task_titles[task]}**")
|
||||||
|
|
||||||
|
# Create histogram data
|
||||||
|
counts_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Category": categories,
|
||||||
|
"Train": [((dataset.y.binned == cat) & (dataset.split == "train")).sum() for cat in categories],
|
||||||
|
"Test": [((dataset.y.binned == cat) & (dataset.split == "test")).sum() for cat in categories],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create stacked bar chart
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
name="Train",
|
||||||
|
x=counts_df["Category"],
|
||||||
|
y=counts_df["Train"],
|
||||||
|
marker_color=colors,
|
||||||
|
opacity=0.9,
|
||||||
|
text=counts_df["Train"],
|
||||||
|
textposition="inside",
|
||||||
|
textfont={"size": 10, "color": "white"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
name="Test",
|
||||||
|
x=counts_df["Category"],
|
||||||
|
y=counts_df["Test"],
|
||||||
|
marker_color=colors,
|
||||||
|
opacity=0.6,
|
||||||
|
text=counts_df["Test"],
|
||||||
|
textposition="inside",
|
||||||
|
textfont={"size": 10, "color": "white"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
barmode="group",
|
||||||
|
height=400,
|
||||||
|
margin={"l": 20, "r": 20, "t": 20, "b": 20},
|
||||||
|
showlegend=True,
|
||||||
|
legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
|
||||||
|
xaxis_title=None,
|
||||||
|
yaxis_title="Count",
|
||||||
|
xaxis={"tickangle": -45},
|
||||||
|
)
|
||||||
|
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
# Show summary statistics
|
||||||
|
total = len(dataset)
|
||||||
|
train_pct = (dataset.split == "train").sum() / total * 100
|
||||||
|
test_pct = (dataset.split == "test").sum() / total * 100
|
||||||
|
|
||||||
|
st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_hex_geometry(geom):
|
||||||
|
"""Fix hexagon geometry crossing the antimeridian."""
|
||||||
|
import antimeridian
|
||||||
|
|
||||||
|
try:
|
||||||
|
return shape(antimeridian.fix_shape(geom))
|
||||||
|
except ValueError as e:
|
||||||
|
st.error(f"Error fixing geometry: {e}")
|
||||||
|
return geom
|
||||||
|
|
||||||
|
|
||||||
|
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
|
||||||
|
"""Assign colors to geodataframe based on the selected color mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gdf: GeoDataFrame to add colors to
|
||||||
|
color_mode: One of 'target_class' or 'split'
|
||||||
|
dataset: CategoricalTrainingDataset
|
||||||
|
selected_task: Task name for color palette selection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeoDataFrame with 'fill_color' column added
|
||||||
|
|
||||||
|
"""
|
||||||
|
if color_mode == "target_class":
|
||||||
|
categories = dataset.y.binned.cat.categories.tolist()
|
||||||
|
colors_palette = get_palette(selected_task, len(categories))
|
||||||
|
|
||||||
|
# Create color mapping
|
||||||
|
color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
|
||||||
|
gdf["color"] = gdf["target_class"].map(color_map)
|
||||||
|
|
||||||
|
# Convert hex colors to RGB
|
||||||
|
def hex_to_rgb(hex_color):
|
||||||
|
hex_color = hex_color.lstrip("#")
|
||||||
|
return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
|
||||||
|
|
||||||
|
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
|
||||||
|
|
||||||
|
elif color_mode == "split":
|
||||||
|
split_colors = {"train": [66, 135, 245], "test": [245, 135, 66]} # Blue # Orange
|
||||||
|
gdf["fill_color"] = gdf["split"].map(split_colors)
|
||||||
|
|
||||||
|
return gdf
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
    """Render a pydeck spatial map showing training data distribution with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (color mode and opacity), without re-running the entire page.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    st.subheader("🗺️ Spatial Distribution Map")

    # Create controls in columns
    col1, col2 = st.columns([3, 1])

    with col1:
        vis_mode = st.selectbox(
            "Visualization mode",
            options=["binary", "count", "density", "split"],
            format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
            key="spatial_map_mode",
        )

    with col2:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="spatial_map_opacity")

    # Determine which task dataset to use and color mode
    if vis_mode == "split":
        # Use binary dataset for split visualization
        dataset = train_data_dict["binary"]
        color_mode = "split"
        selected_task = "binary"
    else:
        # Use the selected task
        dataset = train_data_dict[vis_mode]
        color_mode = "target_class"
        selected_task = vis_mode

    # Prepare data for visualization - dataset.dataset should already be a GeoDataFrame
    gdf: gpd.GeoDataFrame = dataset.dataset.copy()  # type: ignore[assignment]

    # Fix antimeridian issues
    gdf["geometry"] = gdf["geometry"].apply(_fix_hex_geometry)

    # Add binned labels and split information from current dataset
    # NOTE(review): .to_numpy() assumes the dataset series are row-aligned with gdf — confirm upstream.
    gdf["target_class"] = dataset.y.binned.to_numpy()
    gdf["split"] = dataset.split.to_numpy()
    gdf["raw_value"] = dataset.z.to_numpy()

    # Add information from all three tasks for tooltip
    gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
    gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
    gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
    gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
    gdf["density_raw"] = train_data_dict["density"].z.to_numpy()

    # Convert to WGS84 for pydeck
    gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326")  # type: ignore[assignment]

    # Assign colors based on the selected mode
    gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)

    # Convert to GeoJSON format and add elevation for 3D visualization
    geojson_data = []
    # Normalize raw values for elevation (only for count and density)
    use_elevation = vis_mode in ["count", "density"]
    if use_elevation:
        raw_values = gdf_wgs84["raw_value"]
        min_val, max_val = raw_values.min(), raw_values.max()
        # Normalize to 0-1 range for better 3D visualization
        if max_val > min_val:
            gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
        else:
            # Degenerate case: all raw values equal, render flat.
            gdf_wgs84["elevation"] = 0

    # Build one GeoJSON Feature per cell; properties feed both styling and the tooltip.
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "target_class": str(row["target_class"]),
                "split": str(row["split"]),
                "raw_value": float(row["raw_value"]),
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]) if use_elevation else 0,
                "binary_label": str(row["binary_label"]),
                "count_category": str(row["count_category"]),
                "count_raw": int(row["count_raw"]),
                "density_category": str(row["density_category"]),
                "density_raw": f"{float(row['density_raw']):.4f}",
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )

    # Set initial view state (centered on the Arctic)
    # Adjust pitch and zoom based on whether we're using elevation
    view_state = pdk.ViewState(
        latitude=70, longitude=0, zoom=2 if not use_elevation else 1.5, pitch=0 if not use_elevation else 45
    )

    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Binary:</b> {binary_label}<br/>"
            "<b>Count Category:</b> {count_category}<br/>"
            "<b>Count Raw:</b> {count_raw}<br/>"
            "<b>Density Category:</b> {density_category}<br/>"
            "<b>Density Raw:</b> {density_raw}<br/>"
            "<b>Split:</b> {split}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    # Render the map
    st.pydeck_chart(deck)

    # Show info about 3D visualization
    if use_elevation:
        st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")

    # Add legend
    with st.expander("Legend", expanded=True):
        if color_mode == "target_class":
            st.markdown("**Target Classes:**")
            categories = dataset.y.binned.cat.categories.tolist()
            colors_palette = get_palette(selected_task, len(categories))
            intervals = dataset.y.intervals

            # For count and density tasks, show intervals
            if selected_task in ["count", "density"]:
                for i, cat in enumerate(categories):
                    color = colors_palette[i]
                    interval_min, interval_max = intervals[i]

                    # Format interval display
                    if interval_min is None or interval_max is None:
                        # Empty category: no raw values to describe.
                        interval_str = ""
                    elif selected_task == "count":
                        # Integer values for count
                        if interval_min == interval_max:
                            interval_str = f" ({int(interval_min)})"
                        else:
                            interval_str = f" ({int(interval_min)}-{int(interval_max)})"
                    else:  # density
                        # Percentage values for density
                        if interval_min == interval_max:
                            interval_str = f" ({interval_min * 100:.4f}%)"
                        else:
                            interval_str = f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"

                    st.markdown(
                        f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                        f'<div style="width: 20px; height: 20px; background-color: {color}; '
                        f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                        f"<span>{cat}{interval_str}</span></div>",
                        unsafe_allow_html=True,
                    )
            else:
                # Binary task: use original column layout
                legend_cols = st.columns(len(categories))
                for i, cat in enumerate(categories):
                    with legend_cols[i]:
                        color = colors_palette[i]
                        st.markdown(
                            f'<div style="display: flex; align-items: center;">'
                            f'<div style="width: 20px; height: 20px; background-color: {color}; '
                            f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                            f"<span>{cat}</span></div>",
                            unsafe_allow_html=True,
                        )
            if use_elevation:
                st.markdown("---")
                st.markdown("**Elevation (3D):**")
                min_val = gdf_wgs84["raw_value"].min()
                max_val = gdf_wgs84["raw_value"].max()
                st.markdown(f"Height represents raw value: {min_val:.2f} (low) → {max_val:.2f} (high)")
        elif color_mode == "split":
            st.markdown("**Data Split:**")
            legend_html = (
                '<div style="display: flex; gap: 20px;">'
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(66, 135, 245); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Train</span></div>"
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(245, 135, 66); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Test</span></div></div>"
            )
            st.markdown(legend_html, unsafe_allow_html=True)
|
||||||
10
src/entropice/dashboard/training_analysis_page.py
Normal file
10
src/entropice/dashboard/training_analysis_page.py
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
"""Training Results Analysis page: Analysis of training results and model performance."""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
|
||||||
|
def render_training_analysis_page():
    """Render the Training Results Analysis page of the dashboard.

    Currently a placeholder: shows the page title and a short description only.
    """
    st.title("Training Results Analysis")
    st.write("This page will display analysis of training results and model performance.")
    # Add more components and visualizations as needed for training results analysis.
|
||||||
136
src/entropice/dashboard/training_data_page.py
Normal file
136
src/entropice/dashboard/training_data_page.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""Training Data page: Visualization of training data distributions."""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
||||||
|
from entropice.dashboard.utils.data import load_all_training_data
|
||||||
|
from entropice.dataset import DatasetEnsemble
|
||||||
|
|
||||||
|
|
||||||
|
def render_training_data_page():
    """Render the Training Data page of the dashboard.

    Collects the dataset configuration via a sidebar form, builds a
    `DatasetEnsemble` on submit (stored in `st.session_state`), and then
    shows configuration metrics, distribution histograms, and the spatial map.
    """
    st.title("Training Data")

    # Sidebar widgets for dataset configuration in a form
    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")

        # Combined grid and level selection
        grid_options = [
            "hex-3",
            "hex-4",
            "hex-5",
            "hex-6",
            "healpix-6",
            "healpix-7",
            "healpix-8",
            "healpix-9",
            "healpix-10",
        ]

        grid_level_combined = st.selectbox(
            "Grid Configuration", options=grid_options, index=0, help="Select the grid system and resolution level"
        )

        # Parse grid type and level
        grid, level_str = grid_level_combined.split("-")
        level = int(level_str)

        # Target feature selection
        target = st.selectbox(
            "Target Feature",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target variable for training",
        )

        # Members selection
        st.subheader("Dataset Members")

        # Check if AlphaEarth should be disabled
        # (AlphaEarth embeddings do not exist for these grid/level combinations.)
        disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)

        all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
        selected_members = []

        for member in all_members:
            if member == "AlphaEarth" and disable_alphaearth:
                # Show disabled checkbox with explanation
                st.checkbox(
                    member, value=False, disabled=True, help=f"AlphaEarth is not available for {grid} level {level}"
                )
            else:
                if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
                    selected_members.append(member)

        # Form submit button
        load_button = st.form_submit_button(
            "Load Dataset", type="primary", use_container_width=True, disabled=len(selected_members) == 0
        )

    # Create DatasetEnsemble only when form is submitted
    if load_button:
        ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
        # Store ensemble in session state
        st.session_state["dataset_ensemble"] = ensemble
        st.session_state["dataset_loaded"] = True

    # Display dataset information if loaded
    if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state:
        ensemble = st.session_state["dataset_ensemble"]

        # Display current configuration
        st.subheader("📊 Current Configuration")

        # Create a visually appealing layout with columns
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.metric(label="Grid Type", value=ensemble.grid.upper())

        with col2:
            st.metric(label="Grid Level", value=ensemble.level)

        with col3:
            st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))

        with col4:
            st.metric(label="Members", value=len(ensemble.members))

        # Display members in an expandable section
        with st.expander("🗂️ Dataset Members", expanded=False):
            members_cols = st.columns(len(ensemble.members))
            for idx, member in enumerate(ensemble.members):
                with members_cols[idx]:
                    st.markdown(f"✓ **{member}**")

        # Display dataset ID in a styled container
        st.info(f"**Dataset ID:** `{ensemble.id()}`")

        # Load training data for all three tasks
        train_data_dict = load_all_training_data(ensemble)

        # Calculate total samples (use binary as reference)
        total_samples = len(train_data_dict["binary"])
        train_samples = (train_data_dict["binary"].split == "train").sum().item()
        test_samples = (train_data_dict["binary"].split == "test").sum().item()

        st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")

        # Render distribution histograms
        st.markdown("---")
        render_all_distribution_histograms(train_data_dict)

        st.markdown("---")

        # Render spatial map (as a fragment for efficient re-rendering)
        # Extract geometries from the X.data dataframe (which has geometry as a column)
        # The index should be cell_id
        binary_dataset = train_data_dict["binary"]
        assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"

        render_spatial_map(train_data_dict)

        # Add more components and visualizations as needed for training data.
    else:
        st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
|
||||||
101
src/entropice/dashboard/utils/data.py
Normal file
101
src/entropice/dashboard/utils/data.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
"""Data utilities for Entropice dashboard."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import antimeridian
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
import toml
|
||||||
|
from shapely.geometry import shape
|
||||||
|
|
||||||
|
import entropice.paths
|
||||||
|
from entropice.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainingResult:
    """Simple wrapper of training result data."""

    # Human-readable display name built from task, grid-level, and creation time.
    name: str
    # Directory the result was loaded from.
    path: Path
    # The "settings" table of search_settings.toml.
    settings: dict
    # Contents of search_results.parquet.
    results: pd.DataFrame
    # Creation timestamp of the result directory (st_ctime, seconds since epoch).
    created_at: float

    @classmethod
    def from_path(cls, result_path: Path) -> "TrainingResult":
        """Load a TrainingResult from a given result directory path.

        Args:
            result_path: Directory expected to contain search_results.parquet,
                best_estimator_state.nc, predicted_probabilities.parquet and
                search_settings.toml.

        Returns:
            The populated TrainingResult. Note: only the settings and results
            files are read here; the state and predictions files are merely
            checked for existence.

        Raises:
            FileNotFoundError: If any of the four required files is missing.

        """
        result_file = result_path / "search_results.parquet"
        state_file = result_path / "best_estimator_state.nc"
        preds_file = result_path / "predicted_probabilities.parquet"
        settings_file = result_path / "search_settings.toml"
        if not all([result_file.exists(), state_file.exists(), preds_file.exists(), settings_file.exists()]):
            raise FileNotFoundError(f"Missing required files in {result_path}")

        created_at = result_path.stat().st_ctime
        settings = toml.load(settings_file)["settings"]
        results = pd.read_parquet(result_file)

        # Name should be "task grid-level (created_at)"
        name = (
            f"**{settings.get('task', 'Unknown').capitalize()}** -"
            f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}"
            f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})"
        )

        return cls(
            name=name,
            path=result_path,
            settings=settings,
            results=results,
            created_at=created_at,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian.

    On a ValueError from the fix, report it via Streamlit and hand back
    the unmodified input geometry.
    """
    try:
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
    return geom
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
    """Load all training results from the results directory.

    Directories missing any of the required result files are skipped.
    Returns the results sorted by creation time, most recent first.
    """
    loaded: list[TrainingResult] = []
    for candidate in entropice.paths.RESULTS_DIR.iterdir():
        if not candidate.is_dir():
            continue
        try:
            loaded.append(TrainingResult.from_path(candidate))
        except FileNotFoundError:
            # Incomplete run directory — ignore it.
            continue

    return sorted(loaded, key=lambda tr: tr.created_at, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data
def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingDataset]:
    """Load training data for all three tasks.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    return {task: e.create_cat_training_dataset(task) for task in ("binary", "count", "density")}
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
# ruff: noqa: N806
|
||||||
"""Training dataset preparation and model training.
|
"""Training dataset preparation and model training.
|
||||||
|
|
||||||
Naming conventions:
|
Naming conventions:
|
||||||
|
|
@ -14,15 +15,18 @@ Naming conventions:
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from functools import cached_property, lru_cache
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
import entropice.paths
|
import entropice.paths
|
||||||
|
|
||||||
|
|
@ -35,29 +39,111 @@ sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
|
||||||
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
||||||
|
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
|
||||||
if temporal == "yearly":
|
if temporal == "yearly":
|
||||||
return df.index.get_level_values("time").year
|
return time_index.year
|
||||||
elif temporal == "seasonal":
|
elif temporal == "seasonal":
|
||||||
seasons = {10: "winter", 4: "summer"}
|
seasons = {10: "winter", 4: "summer"}
|
||||||
return (
|
return time_index.month.map(lambda x: seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
||||||
df.index.get_level_values("time")
|
|
||||||
.month.map(lambda x: seasons.get(x))
|
|
||||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
|
||||||
)
|
|
||||||
elif temporal == "shoulder":
|
elif temporal == "shoulder":
|
||||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||||
return (
|
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
||||||
df.index.get_level_values("time")
|
|
||||||
.month.map(lambda x: shoulder_seasons.get(x))
|
|
||||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
|
type Task = Literal["binary", "count", "density"]
|
||||||
|
|
||||||
|
|
||||||
|
def bin_values(
    values: pd.Series,
    task: Literal["count", "density"],
    none_val: float = 0,
) -> pd.Series:
    """Bin values into predefined intervals for different tasks.

    First, a 'none' bin is created for values equal to `none_val`, usually 0.
    Then, the remaining values are binned automatically into 5 bins, all containing roughly the same number of samples.

    Args:
        values (pd.Series): Pandas Series of numerical values to bin.
        task (Literal["count", "density"]): Task type - 'count' or 'density'.
        none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0.

    Returns:
        pd.Series: Pandas Series of ordered categorical binned values.

    Raises:
        ValueError: If any value is NaN, or if there are too few non-none values
            to build five quantile bins.

    """
    labels_dict = {
        "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
        "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
    }
    labels = labels_dict[task]

    if values.isna().any():
        raise ValueError("Values contain NaN")

    # Separate none values from others
    none_mask = values == none_val
    non_none_values = values[~none_mask]

    # Raise instead of assert: asserts are stripped under `python -O`,
    # and the docstring already promises a ValueError for bad input.
    if len(non_none_values) <= 5:
        raise ValueError("Not enough non-none values to create bins.")
    binned_non_none = pd.qcut(non_none_values, q=5, labels=labels[1:]).cat.set_categories(labels, ordered=True)
    binned = pd.Series(index=values.index, dtype="category")
    binned = binned.cat.set_categories(labels, ordered=True)
    binned.update(binned_non_none)
    binned.loc[none_mask] = labels[0]
    return binned
||||||
|
|
||||||
|
@dataclass(frozen=True, eq=False)
class DatasetLabels:
    """Label side of a categorical training dataset.

    ``eq=False`` keeps the default identity-based ``__eq__``/``__hash__``,
    so instances remain hashable despite holding pandas/torch members.
    """

    # Ordered-categorical label per sample.
    binned: pd.Series
    # Label tensor for the train split (assumed encoded labels — TODO confirm at construction site).
    train: torch.Tensor
    # Label tensor for the test split (assumed encoded labels — TODO confirm at construction site).
    test: torch.Tensor
    # Unbinned numeric values the categories were derived from.
    raw_values: pd.Series

    @cached_property
    def intervals(self) -> list[tuple[float, float] | tuple[int, int]]:
        """Per-category (min, max) of raw_values; (None, None) for empty categories."""
        # For each category get the min and max values from raw_values
        intervals = []
        for category in self.binned.cat.categories:
            category_mask = self.binned == category
            if category_mask.sum() == 0:
                intervals.append((None, None))
            else:
                category_raw_values = self.raw_values[category_mask]
                intervals.append((category_raw_values.min(), category_raw_values.max()))
        return intervals

    @cached_property
    def labels(self) -> list[str]:
        """All category names in categorical order."""
        return list(self.binned.cat.categories)
||||||
|
|
||||||
|
@dataclass(frozen=True, eq=False)
class DatasetInputs:
    """Input-feature side of a categorical training dataset.

    ``eq=False`` keeps identity-based equality/hashing for the unhashable members.
    """

    # Full feature table, one row per sample.
    data: pd.DataFrame
    # Feature tensor for the train split (assumed — TODO confirm at construction site).
    train: torch.Tensor
    # Feature tensor for the test split (assumed — TODO confirm at construction site).
    test: torch.Tensor
||||||
|
|
||||||
|
@dataclass(frozen=True)
class CategoricalTrainingDataset:
    """Bundle of everything needed to train and evaluate one categorical task."""

    # Source dataframe the inputs/labels were derived from (includes a geometry column).
    dataset: pd.DataFrame
    # Model inputs (features).
    X: DatasetInputs
    # Categorical labels.
    y: DatasetLabels
    # Raw (unbinned) target value per sample.
    z: pd.Series
    # "train"/"test" assignment per sample.
    split: pd.Series

    def __len__(self):
        """Return the number of samples (length of the raw target series)."""
        return len(self.z)
|
||||||
|
|
||||||
@cyclopts.Parameter("*")
|
@cyclopts.Parameter("*")
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class DatasetEnsemble:
|
class DatasetEnsemble:
|
||||||
grid: Literal["hex", "healpix"]
|
grid: Literal["hex", "healpix"]
|
||||||
level: int
|
level: int
|
||||||
|
|
@ -70,17 +156,35 @@ class DatasetEnsemble:
|
||||||
filter_target: str | Literal[False] = False
|
filter_target: str | Literal[False] = False
|
||||||
add_lonlat: bool = True
|
add_lonlat: bool = True
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return int(self.id(), 16)
|
||||||
|
|
||||||
def id(self):
|
def id(self):
|
||||||
return hashlib.blake2b(
|
return hashlib.blake2b(
|
||||||
json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
|
json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
|
||||||
digest_size=16,
|
digest_size=16,
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def covcol(self) -> str:
|
||||||
|
return "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
|
||||||
|
|
||||||
|
def taskcol(self, task: Task) -> str:
|
||||||
|
if task == "binary":
|
||||||
|
return "dartsml_has_rts" if self.target == "darts_mllabels" else "darts_has_rts"
|
||||||
|
elif task == "count":
|
||||||
|
return "dartsml_rts_count" if self.target == "darts_mllabels" else "darts_rts_count"
|
||||||
|
elif task == "density":
|
||||||
|
return "dartsml_rts_density" if self.target == "darts_mllabels" else "darts_rts_density"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid task: {task}")
|
||||||
|
|
||||||
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||||
if member == "AlphaEarth":
|
if member == "AlphaEarth":
|
||||||
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||||
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
||||||
store = entropice.paths.get_era5_stores(member.split("-")[1], grid=self.grid, level=self.level)
|
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||||
|
store = entropice.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
|
||||||
elif member == "ArcticDEM":
|
elif member == "ArcticDEM":
|
||||||
store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
||||||
else:
|
else:
|
||||||
|
|
@ -145,7 +249,7 @@ class DatasetEnsemble:
|
||||||
def _prep_era5(
|
def _prep_era5(
|
||||||
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
|
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
era5 = self._read_member(f"ERA5-{temporal}", targets)
|
era5 = self._read_member("ERA5-" + temporal, targets)
|
||||||
era5_df = era5.to_dataframe()
|
era5_df = era5.to_dataframe()
|
||||||
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
|
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
|
||||||
if "aggregations" not in era5.dims:
|
if "aggregations" not in era5.dims:
|
||||||
|
|
@ -190,9 +294,10 @@ class DatasetEnsemble:
|
||||||
n_cols += n_cols_member
|
n_cols += n_cols_member
|
||||||
print(f"=== Total number of features in dataset: {n_cols}")
|
print(f"=== Total number of features in dataset: {n_cols}")
|
||||||
|
|
||||||
def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
|
@lru_cache(maxsize=1)
|
||||||
|
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
|
||||||
# n: no cache, o: overwrite cache, r: read cache if exists
|
# n: no cache, o: overwrite cache, r: read cache if exists
|
||||||
cache_file = entropice.paths.get_dataset_cache(self.id())
|
cache_file = entropice.paths.get_dataset_cache(self.id(), subset=filter_target_col)
|
||||||
if cache_mode == "r" and cache_file.exists():
|
if cache_mode == "r" and cache_file.exists():
|
||||||
dataset = gpd.read_parquet(cache_file)
|
dataset = gpd.read_parquet(cache_file)
|
||||||
print(
|
print(
|
||||||
|
|
@ -201,11 +306,14 @@ class DatasetEnsemble:
|
||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
targets = self._read_target()
|
targets = self._read_target()
|
||||||
|
if filter_target_col is not None:
|
||||||
|
targets = targets.loc[targets[filter_target_col]]
|
||||||
|
|
||||||
member_dfs = []
|
member_dfs = []
|
||||||
for member in self.members:
|
for member in self.members:
|
||||||
if member.startswith("ERA5"):
|
if member.startswith("ERA5"):
|
||||||
member_dfs.append(self._prep_era5(targets, member.split("-")[1]))
|
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||||
|
member_dfs.append(self._prep_era5(targets, era5_agg))
|
||||||
elif member == "AlphaEarth":
|
elif member == "AlphaEarth":
|
||||||
member_dfs.append(self._prep_embeddings(targets))
|
member_dfs.append(self._prep_embeddings(targets))
|
||||||
elif member == "ArcticDEM":
|
elif member == "ArcticDEM":
|
||||||
|
|
@ -220,3 +328,65 @@ class DatasetEnsemble:
|
||||||
dataset.to_parquet(cache_file)
|
dataset.to_parquet(cache_file)
|
||||||
print(f"Saved dataset to cache at {cache_file}.")
|
print(f"Saved dataset to cache at {cache_file}.")
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
def create_cat_training_dataset(self, task: Task) -> CategoricalTrainingDataset:
|
||||||
|
"""Create a categorical dataset for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): Task type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CategoricalTrainingDataset: The prepared categorical training dataset.
|
||||||
|
|
||||||
|
"""
|
||||||
|
covcol = "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
|
||||||
|
dataset = self.create(filter_target_col=covcol)
|
||||||
|
taskcol = self.taskcol(task)
|
||||||
|
|
||||||
|
valid_labels = dataset[taskcol].notna()
|
||||||
|
|
||||||
|
cols_to_drop = {"geometry", taskcol, covcol}
|
||||||
|
cols_to_drop |= {
|
||||||
|
col
|
||||||
|
for col in dataset.columns
|
||||||
|
if col.startswith("dartsml_" if self.target == "darts_mllabels" else "darts_")
|
||||||
|
}
|
||||||
|
|
||||||
|
model_inputs = dataset.drop(columns=cols_to_drop)
|
||||||
|
# Assert that no column in all-nan
|
||||||
|
assert not model_inputs.isna().all("index").any(), "Some input columns are all NaN"
|
||||||
|
# Get valid inputs (rows)
|
||||||
|
valid_inputs = model_inputs.notna().all("columns")
|
||||||
|
|
||||||
|
dataset = dataset.loc[valid_labels & valid_inputs]
|
||||||
|
model_inputs = model_inputs.loc[valid_labels & valid_inputs]
|
||||||
|
model_labels = dataset[taskcol]
|
||||||
|
|
||||||
|
if task == "binary":
|
||||||
|
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
|
||||||
|
elif task == "count":
|
||||||
|
binned = bin_values(model_labels.astype(int), task=task)
|
||||||
|
elif task == "density":
|
||||||
|
binned = bin_values(model_labels, task=task)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid task.")
|
||||||
|
|
||||||
|
# Create train / test split
|
||||||
|
train_idx, test_idx = train_test_split(dataset.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
|
||||||
|
split = pd.Series(index=dataset.index, dtype=object)
|
||||||
|
split.loc[train_idx] = "train"
|
||||||
|
split.loc[test_idx] = "test"
|
||||||
|
split = split.astype("category")
|
||||||
|
|
||||||
|
X_train = torch.asarray(model_inputs.loc[train_idx].to_numpy(dtype="float64"), device=0)
|
||||||
|
X_test = torch.asarray(model_inputs.loc[test_idx].to_numpy(dtype="float64"), device=0)
|
||||||
|
y_train = torch.asarray(binned.loc[train_idx].cat.codes.to_numpy(dtype="int64"), device=0)
|
||||||
|
y_test = torch.asarray(binned.loc[test_idx].cat.codes.to_numpy(dtype="int64"), device=0)
|
||||||
|
|
||||||
|
return CategoricalTrainingDataset(
|
||||||
|
dataset=dataset.to_crs("EPSG:4326"),
|
||||||
|
X=DatasetInputs(data=model_inputs, train=X_train, test=X_test),
|
||||||
|
y=DatasetLabels(binned=binned, train=y_train, test=y_test, raw_values=model_labels),
|
||||||
|
z=model_labels,
|
||||||
|
split=split,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -177,11 +177,11 @@ def download_daily_aggregated():
|
||||||
tchunksize = era5.chunksizes["time"][0]
|
tchunksize = era5.chunksizes["time"][0]
|
||||||
era5_chunk_starts = pd.date_range(era5.time.min().item(), era5.time.max().item(), freq=f"{tchunksize}h")
|
era5_chunk_starts = pd.date_range(era5.time.min().item(), era5.time.max().item(), freq=f"{tchunksize}h")
|
||||||
closest_chunk_start = era5_chunk_starts[
|
closest_chunk_start = era5_chunk_starts[
|
||||||
era5_chunk_starts.get_indexer([pd.to_datetime(min_time)], method="ffill")[0]
|
era5_chunk_starts.get_indexer([pd.to_datetime(min_time)], method="ffill")[0] # ty:ignore[invalid-argument-type]
|
||||||
]
|
]
|
||||||
subset["time"] = slice(str(closest_chunk_start), max_time)
|
subset["time"] = slice(str(closest_chunk_start), max_time)
|
||||||
|
|
||||||
era5 = era5.sel(**subset)
|
era5 = era5.sel(subset)
|
||||||
|
|
||||||
daily_raw = xr.merge(
|
daily_raw = xr.merge(
|
||||||
[
|
[
|
||||||
|
|
@ -680,7 +680,7 @@ def spatial_agg(
|
||||||
invalid_cell_id = [3059646, 3063547]
|
invalid_cell_id = [3059646, 3063547]
|
||||||
grid_gdf = grid_gdf[~grid_gdf.cell_id.isin(invalid_cell_id)]
|
grid_gdf = grid_gdf[~grid_gdf.cell_id.isin(invalid_cell_id)]
|
||||||
|
|
||||||
aggregations = {
|
aggregations_by_gridlevel: dict[str, dict[int, _Aggregations | Literal["interpolate"]]] = {
|
||||||
"hex": {
|
"hex": {
|
||||||
3: _Aggregations.common(),
|
3: _Aggregations.common(),
|
||||||
4: _Aggregations.common(),
|
4: _Aggregations.common(),
|
||||||
|
|
@ -695,9 +695,9 @@ def spatial_agg(
|
||||||
10: "interpolate",
|
10: "interpolate",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
aggregations = aggregations[grid][level]
|
aggregations = aggregations_by_gridlevel[grid][level]
|
||||||
|
|
||||||
for agg in ["yearly", "seasonal", "shoulder"]:
|
for agg in ("yearly", "seasonal", "shoulder"):
|
||||||
unaligned_store = get_era5_stores(agg)
|
unaligned_store = get_era5_stores(agg)
|
||||||
with stopwatch(f"Loading {agg} ERA5 data"):
|
with stopwatch(f"Loading {agg} ERA5 data"):
|
||||||
unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref").load()
|
unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref").load()
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,11 @@
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
from cuml.ensemble import RandomForestClassifier
|
||||||
|
from entropy import ESPAClassifier
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from sklearn.base import BaseEstimator
|
from xgboost.sklearn import XGBClassifier
|
||||||
|
|
||||||
from entropice.dataset import DatasetEnsemble
|
from entropice.dataset import DatasetEnsemble
|
||||||
|
|
||||||
|
|
@ -16,7 +18,9 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
def predict_proba(e: DatasetEnsemble, clf: BaseEstimator, classes: list) -> gpd.GeoDataFrame:
|
def predict_proba(
|
||||||
|
e: DatasetEnsemble, clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, classes: list
|
||||||
|
) -> gpd.GeoDataFrame:
|
||||||
"""Get predicted probabilities for each cell.
|
"""Get predicted probabilities for each cell.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,15 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None)).resolve() / "entropice"
|
DATA_DIR = (
|
||||||
|
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
||||||
|
)
|
||||||
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
|
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
|
||||||
|
|
||||||
GRIDS_DIR = DATA_DIR / "grids"
|
GRIDS_DIR = DATA_DIR / "grids"
|
||||||
FIGURES_DIR = Path("figures")
|
FIGURES_DIR = Path("figures")
|
||||||
DARTS_DIR = DATA_DIR / "darts"
|
RTS_DIR = DATA_DIR / "darts-rts"
|
||||||
|
RTS_LABELS_DIR = DATA_DIR / "darts-rts-mllabels"
|
||||||
ERA5_DIR = DATA_DIR / "era5"
|
ERA5_DIR = DATA_DIR / "era5"
|
||||||
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
||||||
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||||
|
|
@ -22,7 +25,7 @@ RESULTS_DIR = DATA_DIR / "results"
|
||||||
|
|
||||||
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
DARTS_DIR.mkdir(parents=True, exist_ok=True)
|
RTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
|
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -34,9 +37,9 @@ DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||||
|
|
||||||
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
dartsl2_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||||
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||||
darts_ml_training_labels_repo = DARTS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||||
|
|
||||||
|
|
||||||
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
||||||
|
|
@ -58,9 +61,9 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
|
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
if labels:
|
if labels:
|
||||||
rtsfile = DARTS_DIR / f"{gridname}_darts-mllabels.parquet"
|
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
|
||||||
else:
|
else:
|
||||||
rtsfile = DARTS_DIR / f"{gridname}_darts.parquet"
|
rtsfile = RTS_DIR / f"{gridname}_darts.parquet"
|
||||||
return rtsfile
|
return rtsfile
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -107,8 +110,11 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
return dataset_file
|
return dataset_file
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_cache(eid: str) -> Path:
|
def get_dataset_cache(eid: str, subset: str | None = None) -> Path:
|
||||||
|
if subset is None:
|
||||||
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
|
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
|
||||||
|
else:
|
||||||
|
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_{subset}_dataset.parquet"
|
||||||
return cache_file
|
return cache_file
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -116,7 +122,7 @@ def get_cv_results_dir(
|
||||||
name: str,
|
name: str,
|
||||||
grid: Literal["hex", "healpix"],
|
grid: Literal["hex", "healpix"],
|
||||||
level: int,
|
level: int,
|
||||||
task: Literal["binary", "multi"],
|
task: Literal["binary", "count", "density"],
|
||||||
) -> Path:
|
) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
# ruff: noqa: N806
|
|
||||||
"""Training of classification models training."""
|
"""Training of classification models training."""
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
@ -8,7 +7,6 @@ from typing import Literal
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import toml
|
import toml
|
||||||
import torch
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from cuml.ensemble import RandomForestClassifier
|
from cuml.ensemble import RandomForestClassifier
|
||||||
from cuml.neighbors import KNeighborsClassifier
|
from cuml.neighbors import KNeighborsClassifier
|
||||||
|
|
@ -17,9 +15,9 @@ from rich import pretty, traceback
|
||||||
from scipy.stats import loguniform, randint
|
from scipy.stats import loguniform, randint
|
||||||
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
from sklearn.model_selection import KFold, RandomizedSearchCV
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
from xgboost import XGBClassifier
|
from xgboost.sklearn import XGBClassifier
|
||||||
|
|
||||||
from entropice.dataset import DatasetEnsemble
|
from entropice.dataset import DatasetEnsemble
|
||||||
from entropice.inference import predict_proba
|
from entropice.inference import predict_proba
|
||||||
|
|
@ -30,7 +28,7 @@ pretty.install()
|
||||||
|
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml"))
|
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
|
||||||
|
|
||||||
_metrics = {
|
_metrics = {
|
||||||
"binary": ["accuracy", "recall", "precision", "f1", "jaccard"],
|
"binary": ["accuracy", "recall", "precision", "f1", "jaccard"],
|
||||||
|
|
@ -57,51 +55,6 @@ class CVSettings:
|
||||||
model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
|
model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
|
||||||
|
|
||||||
|
|
||||||
def _create_xy_data(e: DatasetEnsemble, task: Literal["binary", "count", "density"] = "binary"):
|
|
||||||
data = e.create()
|
|
||||||
|
|
||||||
covcol = "dartsml_has_coverage" if e.target == "darts_mllabels" else "darts_has_coverage"
|
|
||||||
bincol = "dartsml_has_rts" if e.target == "darts_mllabels" else "darts_has_rts"
|
|
||||||
countcol = "dartsml_rts_count" if e.target == "darts_mllabels" else "darts_rts_count"
|
|
||||||
densitycol = "dartsml_rts_density" if e.target == "darts_mllabels" else "darts_rts_density"
|
|
||||||
|
|
||||||
data = data[data[covcol]].reset_index(drop=True)
|
|
||||||
|
|
||||||
cols_to_drop = ["geometry"]
|
|
||||||
if e.target == "darts_mllabels":
|
|
||||||
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
|
|
||||||
else:
|
|
||||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
|
||||||
X_data = data.drop(columns=cols_to_drop).dropna()
|
|
||||||
if task == "binary":
|
|
||||||
labels = ["No RTS", "RTS"]
|
|
||||||
y_data = data.loc[X_data.index, bincol]
|
|
||||||
elif task == "count":
|
|
||||||
# Put into n categories (log scaled)
|
|
||||||
y_data = data.loc[X_data.index, countcol]
|
|
||||||
n_categories = 5
|
|
||||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
|
||||||
# Change the first interval to start at 1 and add a category for 0
|
|
||||||
bins = pd.IntervalIndex.from_tuples(
|
|
||||||
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
|
|
||||||
)
|
|
||||||
print(f"{bins=}")
|
|
||||||
y_data = pd.cut(y_data, bins=bins)
|
|
||||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
|
||||||
y_data = y_data.cat.codes
|
|
||||||
elif task == "density":
|
|
||||||
y_data = data.loc[X_data.index, densitycol]
|
|
||||||
n_categories = 5
|
|
||||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
|
||||||
print(f"{bins=}")
|
|
||||||
y_data = pd.cut(y_data, bins=bins)
|
|
||||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
|
||||||
y_data = y_data.cat.codes
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown task: {task}")
|
|
||||||
return data, X_data, y_data, labels
|
|
||||||
|
|
||||||
|
|
||||||
def _create_clf(
|
def _create_clf(
|
||||||
settings: CVSettings,
|
settings: CVSettings,
|
||||||
):
|
):
|
||||||
|
|
@ -196,15 +149,7 @@ def random_cv(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
print("Creating training data...")
|
print("Creating training data...")
|
||||||
_, X_data, y_data, labels = _create_xy_data(dataset_ensemble, task=settings.task)
|
training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task)
|
||||||
print(f"Using {settings.task}-class classification with {len(labels)} classes: {labels}")
|
|
||||||
print(f"{y_data.describe()=}")
|
|
||||||
X = X_data.to_numpy(dtype="float64")
|
|
||||||
y = y_data.to_numpy(dtype="int8")
|
|
||||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
|
||||||
print(f"{X.shape=}, {y.shape=}")
|
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
|
||||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
|
||||||
|
|
||||||
clf, param_grid, fit_params = _create_clf(settings)
|
clf, param_grid, fit_params = _create_clf(settings)
|
||||||
print(f"Using model: {settings.model} with parameters: {param_grid}")
|
print(f"Using model: {settings.model} with parameters: {param_grid}")
|
||||||
|
|
@ -224,14 +169,14 @@ def random_cv(
|
||||||
|
|
||||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||||
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
||||||
search.fit(X_train, y_train, **fit_params)
|
search.fit(training_data.X.train, training_data.y.train, **fit_params)
|
||||||
|
|
||||||
print("Best parameters combination found:")
|
print("Best parameters combination found:")
|
||||||
best_parameters = search.best_estimator_.get_params()
|
best_parameters = search.best_estimator_.get_params()
|
||||||
for param_name in sorted(param_grid.keys()):
|
for param_name in sorted(param_grid.keys()):
|
||||||
print(f"{param_name}: {best_parameters[param_name]}")
|
print(f"{param_name}: {best_parameters[param_name]}")
|
||||||
|
|
||||||
test_accuracy = search.score(X_test, y_test)
|
test_accuracy = search.score(training_data.X.test, training_data.y.test)
|
||||||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||||
|
|
||||||
|
|
@ -251,7 +196,7 @@ def random_cv(
|
||||||
"param_grid": param_grid_serializable,
|
"param_grid": param_grid_serializable,
|
||||||
"cv_splits": cv.get_n_splits(),
|
"cv_splits": cv.get_n_splits(),
|
||||||
"metrics": metrics,
|
"metrics": metrics,
|
||||||
"classes": labels,
|
"classes": training_data.y.labels,
|
||||||
}
|
}
|
||||||
settings_file = results_dir / "search_settings.toml"
|
settings_file = results_dir / "search_settings.toml"
|
||||||
print(f"Storing search settings to {settings_file}")
|
print(f"Storing search settings to {settings_file}")
|
||||||
|
|
@ -267,7 +212,7 @@ def random_cv(
|
||||||
# Store the search results
|
# Store the search results
|
||||||
results = pd.DataFrame(search.cv_results_)
|
results = pd.DataFrame(search.cv_results_)
|
||||||
# Parse the params into individual columns
|
# Parse the params into individual columns
|
||||||
params = pd.json_normalize(results["params"])
|
params = pd.json_normalize(results["params"]) # ty:ignore[invalid-argument-type]
|
||||||
# Concatenate the params columns with the original DataFrame
|
# Concatenate the params columns with the original DataFrame
|
||||||
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
||||||
results_file = results_dir / "search_results.parquet"
|
results_file = results_dir / "search_results.parquet"
|
||||||
|
|
@ -278,7 +223,7 @@ def random_cv(
|
||||||
if settings.model == "espa":
|
if settings.model == "espa":
|
||||||
best_estimator = search.best_estimator_
|
best_estimator = search.best_estimator_
|
||||||
# Annotate the state with xarray metadata
|
# Annotate the state with xarray metadata
|
||||||
features = X_data.columns.tolist()
|
features = training_data.X.data.columns.tolist()
|
||||||
boxes = list(range(best_estimator.K_))
|
boxes = list(range(best_estimator.K_))
|
||||||
box_centers = xr.DataArray(
|
box_centers = xr.DataArray(
|
||||||
best_estimator.S_.cpu().numpy(),
|
best_estimator.S_.cpu().numpy(),
|
||||||
|
|
@ -290,7 +235,7 @@ def random_cv(
|
||||||
box_assignments = xr.DataArray(
|
box_assignments = xr.DataArray(
|
||||||
best_estimator.Lambda_.cpu().numpy(),
|
best_estimator.Lambda_.cpu().numpy(),
|
||||||
dims=["class", "box"],
|
dims=["class", "box"],
|
||||||
coords={"class": labels, "box": boxes},
|
coords={"class": training_data.y.labels, "box": boxes},
|
||||||
name="box_assignments",
|
name="box_assignments",
|
||||||
attrs={"description": "Assignments of samples to boxes."},
|
attrs={"description": "Assignments of samples to boxes."},
|
||||||
)
|
)
|
||||||
|
|
@ -317,7 +262,7 @@ def random_cv(
|
||||||
|
|
||||||
# Predict probabilities for all cells
|
# Predict probabilities for all cells
|
||||||
print("Predicting probabilities for all cells...")
|
print("Predicting probabilities for all cells...")
|
||||||
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=labels)
|
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=training_data.y.labels)
|
||||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||||
print(f"Storing predicted probabilities to {preds_file}")
|
print(f"Storing predicted probabilities to {preds_file}")
|
||||||
preds.to_parquet(preds_file)
|
preds.to_parquet(preds_file)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue