"""Overview dashboard page: dataset analysis (sample & feature counts).

Renders an interactive analysis of the available dataset configurations:
per-task/target/grid sample counts, and total feature counts per grid
configuration and data source.

NOTE(review): this content was recovered from a unified diff whose line
structure had been flattened. The diff's hunks against
``render_overview_page`` were incomplete (context elided between hunks);
they are summarised in a comment at the bottom of this file instead of
being reconstructed.
"""

from datetime import datetime

import pandas as pd
import plotly.express as px
import streamlit as st

from entropice.dashboard.plots.colors import get_palette
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dataset import DatasetEnsemble

# Every grid system / resolution level the dashboard knows about.
GRID_CONFIGS: list[tuple[str, int]] = [
    ("hex", 3),
    ("hex", 4),
    ("hex", 5),
    ("hex", 6),
    ("healpix", 6),
    ("healpix", 7),
    ("healpix", 8),
    ("healpix", 9),
    ("healpix", 10),
]

TARGET_DATASETS: list[str] = ["darts_rts", "darts_mllabels"]
TASKS: list[str] = ["binary", "count", "density"]
ALL_MEMBERS: list[str] = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]


def _alphaearth_unavailable(grid: str, level: int) -> bool:
    """Return True when AlphaEarth is not available for this grid/level.

    NOTE(review): the original code disabled AlphaEarth for healpix level 10
    and hex level 6; presumably those resolutions have no embeddings —
    confirm against the dataset pipeline.
    """
    return (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)


def _members_for(grid: str, level: int) -> list[str]:
    """Return the data-source members usable for a grid configuration."""
    if _alphaearth_unavailable(grid, level):
        return [m for m in ALL_MEMBERS if m != "AlphaEarth"]
    return list(ALL_MEMBERS)


def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination."""
    st.subheader("📊 Sample Counts by Configuration")

    st.markdown(
        """
        This visualization shows the number of available samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )

    sample_rows: list[dict] = []

    with st.spinner("Computing sample counts for all configurations..."):
        for grid, level in GRID_CONFIGS:
            for target in TARGET_DATASETS:
                # Minimal ensemble (no members) just to read the target table.
                # NOTE(review): relies on the private _read_target() — a public
                # accessor on DatasetEnsemble would be preferable.
                ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])  # type: ignore[arg-type]
                targets = ensemble._read_target()
                covcol = ensemble.covcol

                for task in TASKS:
                    taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]

                    # Only count when both columns actually exist in this table.
                    if covcol in targets.columns and taskcol in targets.columns:
                        has_coverage = targets[covcol]
                        has_label = targets[taskcol].notna()
                        sample_rows.append(
                            {
                                "Grid": f"{grid}-{level}",
                                "Grid Type": grid,
                                "Level": level,
                                "Target": target.replace("darts_", ""),
                                "Task": task.capitalize(),
                                "Samples (Coverage)": has_coverage.sum(),
                                "Samples (Labels)": has_label.sum(),
                                "Samples (Both)": (has_coverage & has_label).sum(),
                                # Zero-padded level so lexical sort == numeric sort.
                                "Grid_Level_Sort": f"{grid}_{level:02d}",
                            }
                        )

    sample_df = pd.DataFrame(sample_rows)

    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])

    with tab1:
        st.markdown("### Sample Counts Heatmap")
        st.markdown("Showing counts of samples with both coverage and valid labels")

        for target in TARGET_DATASETS:
            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]

            # Pivot to Grid x Task, then order rows by grid type and level.
            pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
            sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
            pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)

            # Palette list doubles as a continuous color scale for imshow.
            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)

            fig = px.imshow(
                pivot_df,
                labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
                x=pivot_df.columns,
                y=pivot_df.index,
                color_continuous_scale=sample_colors,
                aspect="auto",
                title=f"Target: {target}",
            )
            fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)

    with tab2:
        st.markdown("### Sample Counts Bar Chart")
        st.markdown("Showing counts of samples with both coverage and valid labels")

        task_colors = get_palette("task_types", n_colors=sample_df["Task"].nunique())

        fig = px.bar(
            sample_df,
            x="Grid",
            y="Samples (Both)",
            color="Task",
            facet_col="Target",
            barmode="group",
            title="Sample Counts by Grid Configuration and Target Dataset",
            labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
            color_discrete_sequence=task_colors,
            height=500,
        )
        # Strip the "Target=" prefix from facet titles.
        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
        fig.update_xaxes(tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

    with tab3:
        st.markdown("### Detailed Sample Counts")

        display_df = sample_df[
            ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
        ].copy()
        for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")

        st.dataframe(display_df, hide_index=True, use_container_width=True)


def _collect_grid_comparison() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Compute per-grid feature statistics once, for all comparison tabs.

    The original implementation rebuilt every DatasetEnsemble and re-ran
    get_stats() three separate times (overview loop, stacked-bar tab and
    breakdown tab); this helper does a single pass and returns both tables.

    Returns:
        comparison_df: one row per grid configuration (totals, cells, samples).
        source_df: one row per (grid configuration, data source) with its
            feature count, already sorted by grid type and level.
    """
    comparison_rows: list[dict] = []
    source_rows: list[dict] = []

    for grid, level in GRID_CONFIGS:
        members = _members_for(grid, level)
        # darts_rts is used as the reference target for the comparison.
        ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
        stats = ensemble.get_stats()

        # The smallest cell count over all sources limits inference coverage.
        min_cells = min(
            member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
            for member_stats in stats["members"].values()
        )

        grid_config = f"{grid}-{level}"
        sort_key = f"{grid}_{level:02d}"

        comparison_rows.append(
            {
                "Grid": grid_config,
                "Grid Type": grid,
                "Level": level,
                "Total Features": stats["total_features"],
                "Data Sources": len(members),
                "Inference Cells": min_cells,
                "Total Samples": stats["num_target_samples"],
                "AlphaEarth": "AlphaEarth" in members,
                "Grid_Level_Sort": sort_key,
            }
        )

        for member, member_stats in stats["members"].items():
            source_rows.append(
                {
                    "Grid": grid_config,
                    "Data Source": member,
                    "Number of Features": member_stats["num_features"],
                    "Grid_Level_Sort": sort_key,
                }
            )
        if ensemble.add_lonlat:
            # Longitude/latitude contribute two extra features.
            source_rows.append(
                {
                    "Grid": grid_config,
                    "Data Source": "Lon/Lat",
                    "Number of Features": 2,
                    "Grid_Level_Sort": sort_key,
                }
            )

    comparison_df = pd.DataFrame(comparison_rows)
    source_df = pd.DataFrame(source_rows).sort_values("Grid_Level_Sort")
    return comparison_df, source_df


@st.fragment
def render_feature_count_fragment():
    """Render interactive feature count visualization using fragments."""
    st.subheader("🔢 Feature Counts by Dataset Configuration")

    st.markdown(
        """
        This visualization shows the total number of features that would be generated
        for different combinations of data sources and grid configurations.
        """
    )

    # --- Section 1: comparison across all grid configurations -------------
    st.markdown("### Feature Count Comparison Across Grid Configurations")
    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")

    with st.spinner("Computing feature counts for all grid configurations..."):
        comparison_df, source_df = _collect_grid_comparison()

    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])

    source_colors = get_palette("data_sources", n_colors=source_df["Data Source"].nunique())

    with comp_tab1:
        st.markdown("#### Total Features by Grid Configuration")

        fig = px.bar(
            source_df,
            x="Grid",
            y="Number of Features",
            color="Data Source",
            barmode="stack",
            title="Total Features by Data Source Across Grid Configurations",
            labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
            color_discrete_sequence=source_colors,
            text_auto=False,
        )
        fig.update_layout(height=500, xaxis_tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

        # Secondary metrics: inference cells and samples per configuration.
        col1, col2 = st.columns(2)
        grid_colors = get_palette("grid_configs", n_colors=len(comparison_df))

        with col1:
            fig_cells = px.bar(
                comparison_df,
                x="Grid",
                y="Inference Cells",
                color="Grid",
                title="Inference Cells by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
                color_discrete_sequence=grid_colors,
                text="Inference Cells",
            )
            fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_cells, use_container_width=True)

        with col2:
            fig_samples = px.bar(
                comparison_df,
                x="Grid",
                y="Total Samples",
                color="Grid",
                title="Total Samples by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
                color_discrete_sequence=grid_colors,
                text="Total Samples",
            )
            fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_samples, use_container_width=True)

    with comp_tab2:
        st.markdown("#### Feature Breakdown by Data Source")
        st.markdown("Showing percentage contribution of each data source across all grid configurations")

        # One donut per grid configuration, laid out 3 per row.
        num_grids = len(comparison_df)
        cols_per_row = 3
        num_rows = (num_grids + cols_per_row - 1) // cols_per_row

        for row_idx in range(num_rows):
            cols = st.columns(cols_per_row)
            for col_idx in range(cols_per_row):
                grid_idx = row_idx * cols_per_row + col_idx
                if grid_idx >= num_grids:
                    continue
                grid_config = comparison_df.iloc[grid_idx]["Grid"]
                grid_data = source_df[source_df["Grid"] == grid_config]

                with cols[col_idx]:
                    fig = px.pie(
                        grid_data,
                        names="Data Source",
                        values="Number of Features",
                        title=grid_config,
                        hole=0.4,
                        color_discrete_sequence=source_colors,
                    )
                    fig.update_traces(textposition="inside", textinfo="percent")
                    fig.update_layout(showlegend=True, height=350)
                    st.plotly_chart(fig, use_container_width=True)

    with comp_tab3:
        st.markdown("#### Detailed Feature Count Comparison")

        display_df = comparison_df[
            ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
        ].copy()
        for col in ["Total Features", "Inference Cells", "Total Samples"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
        display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")

        st.dataframe(display_df, hide_index=True, use_container_width=True)

    st.divider()

    # --- Section 2: detailed configuration explorer -----------------------
    st.markdown("### Detailed Configuration Explorer")
    st.markdown("Select specific grid configuration and data sources for detailed statistics")

    grid_options = [f"{grid}-{level}" for grid, level in GRID_CONFIGS]

    col1, col2 = st.columns(2)
    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )
    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=TARGET_DATASETS,
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    grid, level_str = grid_level_combined.split("-")
    level = int(level_str)

    st.markdown("#### Select Data Sources")

    disable_alphaearth = _alphaearth_unavailable(grid, level)

    cols = st.columns(len(ALL_MEMBERS))
    selected_members: list[str] = []
    for idx, member in enumerate(ALL_MEMBERS):
        with cols[idx]:
            if member == "AlphaEarth" and disable_alphaearth:
                st.checkbox(
                    member,
                    value=False,
                    disabled=True,
                    help=f"Not available for {grid} level {level}",
                    key=f"feature_member_{member}",
                )
            elif st.checkbox(member, value=True, key=f"feature_member_{member}"):
                selected_members.append(member)

    if not selected_members:
        st.info("👆 Select at least one data source to see feature statistics")
        return

    st.markdown("---")

    ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
    with st.spinner("Computing dataset statistics..."):
        stats = ensemble.get_stats()

    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Total Features", f"{stats['total_features']:,}")
    with col2:
        # Minimum cells across data sources — the limiting factor for inference.
        min_cells = min(
            member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
            for member_stats in stats["members"].values()
        )
        st.metric(
            "Inference Cells",
            f"{min_cells:,}",
            help="Minimum number of cells across the selected data sources",
        )
    with col3:
        st.metric("Data Sources", len(selected_members))
    with col4:
        st.metric("Total Samples", f"{stats['num_target_samples']:,}")
    with col5:
        total_points = stats["total_features"] * stats["num_target_samples"]
        st.metric("Total Data Points", f"{total_points:,}")

    st.markdown("#### Feature Breakdown by Data Source")

    breakdown_data = [
        {
            "Data Source": member,
            "Number of Features": member_stats["num_features"],
            "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%",  # type: ignore[operator]
        }
        for member, member_stats in stats["members"].items()
    ]
    if ensemble.add_lonlat:
        breakdown_data.append(
            {
                "Data Source": "Lon/Lat",
                "Number of Features": 2,
                "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
            }
        )

    breakdown_df = pd.DataFrame(breakdown_data)
    pie_colors = get_palette("data_sources", n_colors=len(breakdown_df))

    fig = px.pie(
        breakdown_df,
        names="Data Source",
        values="Number of Features",
        title="Feature Distribution by Data Source",
        hole=0.4,
        color_discrete_sequence=pie_colors,
    )
    fig.update_traces(textposition="inside", textinfo="percent+label")
    st.plotly_chart(fig, use_container_width=True)

    st.dataframe(breakdown_df, hide_index=True, use_container_width=True)

    with st.expander("📦 Detailed Source Information", expanded=False):
        for member, member_stats in stats["members"].items():
            st.markdown(f"### {member}")

            metric_cols = st.columns(4)
            with metric_cols[0]:
                st.metric("Features", member_stats["num_features"])
            with metric_cols[1]:
                st.metric("Variables", member_stats["num_variables"])
            with metric_cols[2]:
                dim_str = " x ".join(str(dim) for dim in member_stats["dimensions"].values())  # type: ignore[union-attr]
                st.metric("Shape", dim_str)
            with metric_cols[3]:
                total_points = 1
                for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
                    total_points *= dim_size
                st.metric("Data Points", f"{total_points:,}")

            # NOTE(review): the recovered source rendered variables/dimensions
            # as HTML badges whose markup was lost in transit; plain inline
            # code is used here instead — restore the styled badges if the
            # original markup can be recovered.
            st.markdown("**Variables:**")
            st.markdown(" ".join(f"`{v}`" for v in member_stats["variables"]))  # type: ignore[union-attr]

            st.markdown("**Dimensions:**")
            st.markdown(
                " ".join(
                    f"`{dim_name}: {dim_size}`"
                    for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
                )
            )

            st.markdown("---")


def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts."""
    st.header("📈 Dataset Analysis")

    # Two tabs: per-configuration sample counts, and feature counts.
    analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
    with analysis_tabs[0]:
        render_sample_count_overview()
    with analysis_tabs[1]:
        render_feature_count_fragment()


# NOTE(review): the recovered diff also edited render_overview_page(), but its
# hunks were incomplete (context elided), so the function is not reproduced
# here. The visible changes were:
#   * st.divider() inserted after the "Found N training result(s)" line, and
#     "Summary Statistics" renamed to st.header("📊 Training Results Summary");
#   * render_dataset_analysis() (plus dividers) inserted before the results
#     table, and "Training Results" renamed to st.header("🎯 Experiment
#     Results") + st.subheader("Results Table");
#   * "Detailed Results" renamed to st.subheader("Individual Experiment
#     Details").