Extend the dashboard to use altair
This commit is contained in:
parent
d498b1e752
commit
150f14ed52
6 changed files with 886 additions and 265 deletions
61
pixi.lock
generated
61
pixi.lock
generated
|
|
@ -94,6 +94,7 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/f7/0d/4764669bdf47bd472899b3d3db91fffbe925c8e3038ec591a2fd2ad6a14d/aiohttp-3.13.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://files.pythonhosted.org/packages/f7/0d/4764669bdf47bd472899b3d3db91fffbe925c8e3038ec591a2fd2ad6a14d/aiohttp-3.13.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
|
||||||
|
|
@ -209,6 +210,7 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/22/ff/6425bf5c20d79aa5b959d1ce9e65f599632345391381c9a104133fe0b171/matplotlib-3.10.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
- pypi: https://files.pythonhosted.org/packages/22/ff/6425bf5c20d79aa5b959d1ce9e65f599632345391381c9a104133fe0b171/matplotlib-3.10.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/c6/2d/f0b184fa88d6630aa267680bdb8623fb69cb0d024b8c6f0d23f9a0f406d3/multidict-6.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://files.pythonhosted.org/packages/c6/2d/f0b184fa88d6630aa267680bdb8623fb69cb0d024b8c6f0d23f9a0f406d3/multidict-6.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
|
||||||
|
|
@ -300,6 +302,9 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
|
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl
|
||||||
|
|
@ -417,6 +422,27 @@ packages:
|
||||||
- sphinxext-altair ; extra == 'doc'
|
- sphinxext-altair ; extra == 'doc'
|
||||||
- vl-convert-python>=1.7.0 ; extra == 'save'
|
- vl-convert-python>=1.7.0 ; extra == 'save'
|
||||||
requires_python: '>=3.9'
|
requires_python: '>=3.9'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
|
||||||
|
name: altair-tiles
|
||||||
|
version: 0.4.0
|
||||||
|
sha256: eeed1a6d89800f6cf5aafa6a59ee735bc7c243cd133acebccabfdbf69cc7e33c
|
||||||
|
requires_dist:
|
||||||
|
- altair
|
||||||
|
- mercantile
|
||||||
|
- xyzservices
|
||||||
|
- geopandas ; extra == 'dev'
|
||||||
|
- ghp-import ; extra == 'dev'
|
||||||
|
- hatch ; extra == 'dev'
|
||||||
|
- ipykernel ; extra == 'dev'
|
||||||
|
- ipython ; extra == 'dev'
|
||||||
|
- mypy ; extra == 'dev'
|
||||||
|
- pytest ; extra == 'dev'
|
||||||
|
- ruff>=0.1.4 ; extra == 'dev'
|
||||||
|
- vega-datasets ; extra == 'dev'
|
||||||
|
- vl-convert-python ; extra == 'dev'
|
||||||
|
- jupyter-book ; extra == 'doc'
|
||||||
|
- vl-convert-python ; extra == 'doc'
|
||||||
|
requires_python: '>=3.9'
|
||||||
- pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
|
||||||
name: anywidget
|
name: anywidget
|
||||||
version: 0.9.18
|
version: 0.9.18
|
||||||
|
|
@ -1274,7 +1300,7 @@ packages:
|
||||||
- pypi: ./
|
- pypi: ./
|
||||||
name: entropice
|
name: entropice
|
||||||
version: 0.1.0
|
version: 0.1.0
|
||||||
sha256: 39f2dabdc6891e121e03650dfde2f69b084370df8561d63478c6b3b518530e54
|
sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0
|
||||||
requires_dist:
|
requires_dist:
|
||||||
- aiohttp>=3.12.11
|
- aiohttp>=3.12.11
|
||||||
- bokeh>=3.7.3
|
- bokeh>=3.7.3
|
||||||
|
|
@ -1292,7 +1318,6 @@ packages:
|
||||||
- geemap>=0.36.3
|
- geemap>=0.36.3
|
||||||
- geopandas>=1.1.0
|
- geopandas>=1.1.0
|
||||||
- h3>=4.2.2
|
- h3>=4.2.2
|
||||||
- h5netcdf>=1.6.4
|
|
||||||
- ipycytoscape>=1.3.3
|
- ipycytoscape>=1.3.3
|
||||||
- ipykernel>=6.29.5
|
- ipykernel>=6.29.5
|
||||||
- ipywidgets>=8.1.7
|
- ipywidgets>=8.1.7
|
||||||
|
|
@ -1321,6 +1346,8 @@ packages:
|
||||||
- zarr[remote]>=3.1.3
|
- zarr[remote]>=3.1.3
|
||||||
- geocube>=0.7.1,<0.8
|
- geocube>=0.7.1,<0.8
|
||||||
- streamlit>=1.50.0,<2
|
- streamlit>=1.50.0,<2
|
||||||
|
- altair[all]>=5.5.0,<6
|
||||||
|
- h5netcdf>=1.7.3,<2
|
||||||
editable: true
|
editable: true
|
||||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||||
name: entropy
|
name: entropy
|
||||||
|
|
@ -3159,6 +3186,15 @@ packages:
|
||||||
version: 0.1.2
|
version: 0.1.2
|
||||||
sha256: 84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8
|
sha256: 84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8
|
||||||
requires_python: '>=3.7'
|
requires_python: '>=3.7'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
|
||||||
|
name: mercantile
|
||||||
|
version: 1.2.1
|
||||||
|
sha256: 30f457a73ee88261aab787b7069d85961a5703bb09dc57a170190bc042cd023f
|
||||||
|
requires_dist:
|
||||||
|
- click>=3.0
|
||||||
|
- check-manifest ; extra == 'dev'
|
||||||
|
- hypothesis ; extra == 'test'
|
||||||
|
- pytest ; extra == 'test'
|
||||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2024.2.2-ha770c72_17.conda
|
- conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2024.2.2-ha770c72_17.conda
|
||||||
sha256: 1e59d0dc811f150d39c2ff2da930d69dcb91cb05966b7df5b7d85133006668ed
|
sha256: 1e59d0dc811f150d39c2ff2da930d69dcb91cb05966b7df5b7d85133006668ed
|
||||||
md5: e4ab075598123e783b788b995afbdad0
|
md5: e4ab075598123e783b788b995afbdad0
|
||||||
|
|
@ -4728,6 +4764,27 @@ packages:
|
||||||
- pysocks>=1.5.6,!=1.5.7,<2.0 ; extra == 'socks'
|
- pysocks>=1.5.6,!=1.5.7,<2.0 ; extra == 'socks'
|
||||||
- zstandard>=0.18.0 ; extra == 'zstd'
|
- zstandard>=0.18.0 ; extra == 'zstd'
|
||||||
requires_python: '>=3.9'
|
requires_python: '>=3.9'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl
|
||||||
|
name: vega-datasets
|
||||||
|
version: 0.9.0
|
||||||
|
sha256: 3d7c63917be6ca9b154b565f4779a31fedce57b01b5b9d99d8a34a7608062a1d
|
||||||
|
requires_dist:
|
||||||
|
- pandas
|
||||||
|
requires_python: '>=3.5'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
|
name: vegafusion
|
||||||
|
version: 2.0.3
|
||||||
|
sha256: 0b11c19a70f1bfe3d23d0a09aeecaac7bd03fac01a966d69fbd4dd8679dcb7e7
|
||||||
|
requires_dist:
|
||||||
|
- arro3-core
|
||||||
|
- packaging
|
||||||
|
- narwhals>=1.42
|
||||||
|
requires_python: '>=3.9'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
|
name: vl-convert-python
|
||||||
|
version: 1.8.0
|
||||||
|
sha256: b51264998e8fcc43dbce801484a950cfe6513cdc4c46b20604ef50989855a617
|
||||||
|
requires_python: '>=3.7'
|
||||||
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
|
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
|
||||||
name: watchdog
|
name: watchdog
|
||||||
version: 6.0.0
|
version: 6.0.0
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ dependencies = [
|
||||||
"geemap>=0.36.3",
|
"geemap>=0.36.3",
|
||||||
"geopandas>=1.1.0",
|
"geopandas>=1.1.0",
|
||||||
"h3>=4.2.2",
|
"h3>=4.2.2",
|
||||||
"h5netcdf>=1.6.4",
|
|
||||||
"ipycytoscape>=1.3.3",
|
"ipycytoscape>=1.3.3",
|
||||||
"ipykernel>=6.29.5",
|
"ipykernel>=6.29.5",
|
||||||
"ipywidgets>=8.1.7",
|
"ipywidgets>=8.1.7",
|
||||||
|
|
@ -50,7 +49,7 @@ dependencies = [
|
||||||
"xvec>=0.5.1",
|
"xvec>=0.5.1",
|
||||||
"zarr[remote]>=3.1.3",
|
"zarr[remote]>=3.1.3",
|
||||||
"geocube>=0.7.1,<0.8",
|
"geocube>=0.7.1,<0.8",
|
||||||
"streamlit>=1.50.0,<2",
|
"streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
0
src/entropice/inference.py
Normal file
0
src/entropice/inference.py
Normal file
|
|
@ -87,10 +87,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
return dataset_file
|
return dataset_file
|
||||||
|
|
||||||
|
|
||||||
def get_cv_results_file(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_cv_results_dir(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}"
|
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}"
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
results_file = results_dir / "search_results.parquet"
|
return results_dir
|
||||||
return results_file
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,10 @@
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
"""Training dataset preparation and model training."""
|
"""Training dataset preparation and model training."""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import matplotlib.colors as mcolors
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -21,7 +17,7 @@ from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.paths import (
|
from entropice.paths import (
|
||||||
get_cv_results_file,
|
get_cv_results_dir,
|
||||||
get_darts_rts_file,
|
get_darts_rts_file,
|
||||||
get_embeddings_store,
|
get_embeddings_store,
|
||||||
get_era5_stores,
|
get_era5_stores,
|
||||||
|
|
@ -55,23 +51,41 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
|
||||||
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
||||||
|
|
||||||
# Get era5 data
|
# Get era5 data
|
||||||
era5_store = get_era5_stores("yearly", grid=grid, level=level)
|
|
||||||
era5 = xr.open_zarr(era5_store, consolidated=False)
|
|
||||||
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
|
||||||
|
|
||||||
era5_df = []
|
era5_df = []
|
||||||
for var in era5.data_vars:
|
|
||||||
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
|
||||||
df["year"] = df.index.get_level_values("time").year
|
|
||||||
df = (
|
|
||||||
df.pivot_table(index="cell_ids", columns="year", values=var)
|
|
||||||
.rename(columns=lambda x: f"{var}_{x}")
|
|
||||||
.rename_axis(None, axis=1)
|
|
||||||
)
|
|
||||||
era5_df.append(df)
|
|
||||||
era5_df = pd.concat(era5_df, axis=1)
|
|
||||||
|
|
||||||
# TODO: season and shoulder data
|
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||||
|
seasons = {
|
||||||
|
10: "winter",
|
||||||
|
4: "summer",
|
||||||
|
}
|
||||||
|
for temporal in ["yearly", "seasonal", "shoulder"]:
|
||||||
|
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
||||||
|
era5 = xr.open_zarr(era5_store, consolidated=False)
|
||||||
|
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
||||||
|
|
||||||
|
for var in era5.data_vars:
|
||||||
|
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
||||||
|
if temporal == "yearly":
|
||||||
|
df["t"] = df.index.get_level_values("time").year
|
||||||
|
elif temporal == "seasonal":
|
||||||
|
df["t"] = (
|
||||||
|
df.index.get_level_values("time")
|
||||||
|
.month.map(lambda x: seasons.get(x))
|
||||||
|
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||||
|
)
|
||||||
|
elif temporal == "shoulder":
|
||||||
|
df["t"] = (
|
||||||
|
df.index.get_level_values("time")
|
||||||
|
.month.map(lambda x: shoulder_seasons.get(x))
|
||||||
|
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||||
|
)
|
||||||
|
df = (
|
||||||
|
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||||
|
.rename(columns=lambda x: f"{var}_{x}")
|
||||||
|
.rename_axis(None, axis=1)
|
||||||
|
)
|
||||||
|
era5_df.append(df)
|
||||||
|
era5_df = pd.concat(era5_df, axis=1)
|
||||||
|
|
||||||
# Get embeddings data
|
# Get embeddings data
|
||||||
embs_store = get_embeddings_store(grid=grid, level=level)
|
embs_store = get_embeddings_store(grid=grid, level=level)
|
||||||
|
|
@ -83,6 +97,7 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
|
||||||
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||||
|
|
||||||
# Combine datasets by cell id / cell
|
# Combine datasets by cell id / cell
|
||||||
|
# TODO: use prefixes to easily split the features again during analysis
|
||||||
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
|
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
|
||||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||||
|
|
||||||
|
|
@ -91,12 +106,13 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def random_cv(grid: Literal["hex", "healpix"], level: int):
|
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
|
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data = get_train_dataset_file(grid=grid, level=level)
|
data = get_train_dataset_file(grid=grid, level=level)
|
||||||
|
|
@ -125,8 +141,8 @@ def random_cv(grid: Literal["hex", "healpix"], level: int):
|
||||||
search = RandomizedSearchCV(
|
search = RandomizedSearchCV(
|
||||||
clf,
|
clf,
|
||||||
param_grid,
|
param_grid,
|
||||||
n_iter=20,
|
n_iter=n_iter,
|
||||||
n_jobs=24,
|
n_jobs=20,
|
||||||
cv=cv,
|
cv=cv,
|
||||||
random_state=42,
|
random_state=42,
|
||||||
verbose=10,
|
verbose=10,
|
||||||
|
|
@ -155,141 +171,57 @@ def random_cv(grid: Literal["hex", "healpix"], level: int):
|
||||||
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
||||||
results["grid"] = grid
|
results["grid"] = grid
|
||||||
results["level"] = level
|
results["level"] = level
|
||||||
results_file = get_cv_results_file("random_search", grid=grid, level=level)
|
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
||||||
|
results_file = results_dir / "search_results.parquet"
|
||||||
print(f"Storing CV results to {results_file}")
|
print(f"Storing CV results to {results_file}")
|
||||||
results.to_parquet(results_file)
|
results.to_parquet(results_file)
|
||||||
|
|
||||||
|
# Get the inner state of the best estimator
|
||||||
|
best_estimator = search.best_estimator_
|
||||||
|
# Annotate the state with xarray metadata
|
||||||
|
features = X_data.columns.tolist()
|
||||||
|
labels = y_data.unique().tolist()
|
||||||
|
boxes = list(range(best_estimator.K_))
|
||||||
|
box_centers = xr.DataArray(
|
||||||
|
best_estimator.S_.cpu().numpy(),
|
||||||
|
dims=["feature", "box"],
|
||||||
|
coords={"feature": features, "box": boxes},
|
||||||
|
name="box_centers",
|
||||||
|
attrs={"description": "Centers of the boxes in feature space."},
|
||||||
|
)
|
||||||
|
box_assignments = xr.DataArray(
|
||||||
|
best_estimator.Lambda_.cpu().numpy(),
|
||||||
|
dims=["class", "box"],
|
||||||
|
coords={"class": labels, "box": boxes},
|
||||||
|
name="box_assignments",
|
||||||
|
attrs={"description": "Assignments of samples to boxes."},
|
||||||
|
)
|
||||||
|
feature_weights = xr.DataArray(
|
||||||
|
best_estimator.W_.cpu().numpy(),
|
||||||
|
dims=["feature"],
|
||||||
|
coords={"feature": features},
|
||||||
|
name="feature_weights",
|
||||||
|
attrs={"description": "Feature weights for each box."},
|
||||||
|
)
|
||||||
|
state = xr.Dataset(
|
||||||
|
{
|
||||||
|
"box_centers": box_centers,
|
||||||
|
"box_assignments": box_assignments,
|
||||||
|
"feature_weights": feature_weights,
|
||||||
|
},
|
||||||
|
attrs={
|
||||||
|
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
|
||||||
|
"grid": grid,
|
||||||
|
"level": level,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state_file = results_dir / "best_estimator_state.nc"
|
||||||
|
print(f"Storing best estimator state to {state_file}")
|
||||||
|
state.to_netcdf(state_file, engine="h5netcdf")
|
||||||
|
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
print("Done.")
|
print("Done.")
|
||||||
|
|
||||||
plot_random_cv_results(results_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_k_binned(
|
|
||||||
results: pd.DataFrame, target: str, *, vmin_percentile: float | None = None, vmax_percentile: float | None = None
|
|
||||||
):
|
|
||||||
assert vmin_percentile is None or vmax_percentile is None, (
|
|
||||||
"Only one of vmin_percentile or vmax_percentile can be set."
|
|
||||||
)
|
|
||||||
assert "initial_K_binned" in results.columns, "initial_K_binned column not found in results."
|
|
||||||
assert target in results.columns, f"{target} column not found in results."
|
|
||||||
assert "eps_e" in results.columns, "eps_e column not found in results."
|
|
||||||
assert "eps_cl" in results.columns, "eps_cl column not found in results."
|
|
||||||
|
|
||||||
# add a colorbar instead of the sampled legend
|
|
||||||
cmap = sns.color_palette("ch:", as_cmap=True)
|
|
||||||
# sophisticated normalization
|
|
||||||
if vmin_percentile is not None:
|
|
||||||
vmin = np.percentile(results[target], vmin_percentile)
|
|
||||||
norm = mcolors.Normalize(vmin=vmin)
|
|
||||||
elif vmax_percentile is not None:
|
|
||||||
vmax = np.percentile(results[target], vmax_percentile)
|
|
||||||
norm = mcolors.Normalize(vmax=vmax)
|
|
||||||
else:
|
|
||||||
norm = mcolors.Normalize()
|
|
||||||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
||||||
|
|
||||||
# nice col-wrap based on columns
|
|
||||||
n_cols = results["initial_K_binned"].unique().size
|
|
||||||
col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)
|
|
||||||
|
|
||||||
scatter = sns.relplot(
|
|
||||||
data=results,
|
|
||||||
x="eps_e",
|
|
||||||
y="eps_cl",
|
|
||||||
hue=target,
|
|
||||||
hue_norm=sm.norm,
|
|
||||||
palette=cmap,
|
|
||||||
legend=False,
|
|
||||||
col="initial_K_binned",
|
|
||||||
col_wrap=col_wrap,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply log scale to all axes
|
|
||||||
for ax in scatter.axes.flat:
|
|
||||||
ax.set_xscale("log")
|
|
||||||
ax.set_yscale("log")
|
|
||||||
|
|
||||||
# Tight layout
|
|
||||||
scatter.figure.tight_layout()
|
|
||||||
|
|
||||||
# Add a shared colorbar at the bottom
|
|
||||||
scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar
|
|
||||||
cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height]
|
|
||||||
cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
|
|
||||||
cbar.set_label(target)
|
|
||||||
|
|
||||||
return scatter
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
|
||||||
assert "initial_K" in results.columns, "initial_K column not found in results."
|
|
||||||
assert metric in results.columns, f"{metric} not found in results."
|
|
||||||
|
|
||||||
if target == "eps_cl":
|
|
||||||
hue = "eps_cl"
|
|
||||||
col = "eps_e_binned"
|
|
||||||
elif target == "eps_e":
|
|
||||||
hue = "eps_e"
|
|
||||||
col = "eps_cl_binned"
|
|
||||||
assert hue in results.columns, f"{hue} column not found in results."
|
|
||||||
assert col in results.columns, f"{col} column not found in results."
|
|
||||||
|
|
||||||
return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
def plot_random_cv_results(file: Path):
|
|
||||||
"""Plot analysis of the results from the RandomCVSearch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (Path): The file of the results.
|
|
||||||
|
|
||||||
"""
|
|
||||||
print(f"Plotting random CV results from {file}...")
|
|
||||||
results = pd.read_parquet(file)
|
|
||||||
# Bin the initial_K into 40er bins
|
|
||||||
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
|
|
||||||
# Bin the eps_cl and eps_e into logarithmic bins
|
|
||||||
eps_cl_bins = np.logspace(-3, 7, num=10)
|
|
||||||
eps_e_bins = np.logspace(-3, 7, num=10)
|
|
||||||
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
|
||||||
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
|
||||||
|
|
||||||
figdir = file.parent
|
|
||||||
|
|
||||||
# K-Plots
|
|
||||||
metrics = ["f1"]
|
|
||||||
for metric in metrics:
|
|
||||||
_plot_k_binned(
|
|
||||||
results,
|
|
||||||
f"mean_test_{metric}",
|
|
||||||
vmin_percentile=50,
|
|
||||||
).figure.savefig(figdir / f"params3d-mean_{metric}.pdf")
|
|
||||||
_plot_k_binned(
|
|
||||||
results,
|
|
||||||
f"std_test_{metric}",
|
|
||||||
vmax_percentile=50,
|
|
||||||
).figure.savefig(figdir / f"params3d-std_{metric}.pdf")
|
|
||||||
_plot_k_binned(results, f"mean_test_{metric}").figure.savefig(figdir / f"params3d-mean_{metric}-noperc.pdf")
|
|
||||||
_plot_k_binned(results, f"std_test_{metric}").figure.savefig(figdir / f"params3d-std_{metric}-noperc.pdf")
|
|
||||||
|
|
||||||
# eps-Plots
|
|
||||||
_plot_eps_binned(
|
|
||||||
results,
|
|
||||||
"eps_cl",
|
|
||||||
f"mean_test_{metric}",
|
|
||||||
).figure.savefig(figdir / f"k-eps_cl-mean_{metric}.pdf")
|
|
||||||
_plot_eps_binned(
|
|
||||||
results,
|
|
||||||
"eps_e",
|
|
||||||
f"mean_test_{metric}",
|
|
||||||
).figure.savefig(figdir / f"k-eps_e-mean_{metric}.pdf")
|
|
||||||
|
|
||||||
# Close all figures
|
|
||||||
plt.close("all")
|
|
||||||
print("Done.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,16 @@
|
||||||
"""Streamlit dashboard for training analysis results visualization."""
|
"""Streamlit dashboard for training analysis results visualization."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import matplotlib.colors as mcolors
|
import altair as alt
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
from entropice.paths import RESULTS_DIR
|
from entropice.paths import RESULTS_DIR
|
||||||
|
|
||||||
sns.set_theme("talk", "whitegrid")
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_k_binned(
|
def _plot_k_binned(
|
||||||
results: pd.DataFrame,
|
results: pd.DataFrame,
|
||||||
|
|
@ -30,50 +28,47 @@ def _plot_k_binned(
|
||||||
assert "eps_e" in results.columns, "eps_e column not found in results."
|
assert "eps_e" in results.columns, "eps_e column not found in results."
|
||||||
assert "eps_cl" in results.columns, "eps_cl column not found in results."
|
assert "eps_cl" in results.columns, "eps_cl column not found in results."
|
||||||
|
|
||||||
# add a colorbar instead of the sampled legend
|
# Prepare data
|
||||||
cmap = sns.color_palette("ch:", as_cmap=True)
|
plot_data = results[["eps_e", "eps_cl", "initial_K_binned", target]].copy()
|
||||||
# sophisticated normalization
|
|
||||||
|
# Sort bins by their left value and convert to string with sorted categories
|
||||||
|
plot_data = plot_data.sort_values("initial_K_binned")
|
||||||
|
plot_data["initial_K_binned"] = plot_data["initial_K_binned"].astype(str)
|
||||||
|
bin_order = plot_data["initial_K_binned"].unique().tolist()
|
||||||
|
|
||||||
|
# Determine color scale domain
|
||||||
if vmin_percentile is not None:
|
if vmin_percentile is not None:
|
||||||
vmin = np.percentile(results[target], vmin_percentile)
|
vmin = np.percentile(results[target], vmin_percentile)
|
||||||
norm = mcolors.Normalize(vmin=vmin)
|
color_scale = alt.Scale(scheme="viridis", domain=[vmin, plot_data[target].max()])
|
||||||
elif vmax_percentile is not None:
|
elif vmax_percentile is not None:
|
||||||
vmax = np.percentile(results[target], vmax_percentile)
|
vmax = np.percentile(results[target], vmax_percentile)
|
||||||
norm = mcolors.Normalize(vmax=vmax)
|
color_scale = alt.Scale(scheme="viridis", domain=[plot_data[target].min(), vmax])
|
||||||
else:
|
else:
|
||||||
norm = mcolors.Normalize()
|
color_scale = alt.Scale(scheme="viridis")
|
||||||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
||||||
|
|
||||||
# nice col-wrap based on columns
|
# Create the chart
|
||||||
n_cols = results["initial_K_binned"].unique().size
|
chart = (
|
||||||
col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)
|
alt.Chart(plot_data)
|
||||||
|
.mark_circle(size=60, opacity=0.7)
|
||||||
scatter = sns.relplot(
|
.encode(
|
||||||
data=results,
|
x=alt.X(
|
||||||
x="eps_e",
|
"eps_e:Q",
|
||||||
y="eps_cl",
|
scale=alt.Scale(type="log"),
|
||||||
hue=target,
|
axis=alt.Axis(title="eps_e", grid=True, gridOpacity=0.5),
|
||||||
hue_norm=sm.norm,
|
),
|
||||||
palette=cmap,
|
y=alt.Y(
|
||||||
legend=False,
|
"eps_cl:Q",
|
||||||
col="initial_K_binned",
|
scale=alt.Scale(type="log"),
|
||||||
col_wrap=col_wrap,
|
axis=alt.Axis(title="eps_cl", grid=True, gridOpacity=0.5),
|
||||||
|
),
|
||||||
|
color=alt.Color(f"{target}:Q", scale=color_scale, title=target),
|
||||||
|
tooltip=["eps_e:Q", "eps_cl:Q", alt.Tooltip(f"{target}:Q", format=".4f"), "initial_K_binned:N"],
|
||||||
|
)
|
||||||
|
.properties(width=200, height=200)
|
||||||
|
.facet(facet=alt.Facet("initial_K_binned:N", title="Initial K", sort=bin_order), columns=5)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply log scale to all axes
|
return chart
|
||||||
for ax in scatter.axes.flat:
|
|
||||||
ax.set_xscale("log")
|
|
||||||
ax.set_yscale("log")
|
|
||||||
|
|
||||||
# Tight layout
|
|
||||||
scatter.figure.tight_layout()
|
|
||||||
|
|
||||||
# Add a shared colorbar at the bottom
|
|
||||||
scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar
|
|
||||||
cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height]
|
|
||||||
cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
|
|
||||||
cbar.set_label(target)
|
|
||||||
|
|
||||||
return scatter
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
||||||
|
|
@ -93,25 +88,183 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
||||||
assert hue in results.columns, f"{hue} column not found in results."
|
assert hue in results.columns, f"{hue} column not found in results."
|
||||||
assert col in results.columns, f"{col} column not found in results."
|
assert col in results.columns, f"{col} column not found in results."
|
||||||
|
|
||||||
return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
|
# Prepare data
|
||||||
|
plot_data = results[["initial_K", metric, hue, col]].copy()
|
||||||
|
|
||||||
|
# Sort bins by their left value and convert to string with sorted categories
|
||||||
|
plot_data = plot_data.sort_values(col)
|
||||||
|
plot_data[col] = plot_data[col].astype(str)
|
||||||
|
bin_order = plot_data[col].unique().tolist()
|
||||||
|
|
||||||
|
# Create the chart
|
||||||
|
chart = (
|
||||||
|
alt.Chart(plot_data)
|
||||||
|
.mark_circle(size=60, opacity=0.7)
|
||||||
|
.encode(
|
||||||
|
x=alt.X("initial_K:Q", title="Initial K"),
|
||||||
|
y=alt.Y(f"{metric}:Q", title=metric),
|
||||||
|
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme="viridis"), title=hue),
|
||||||
|
tooltip=[
|
||||||
|
"initial_K:Q",
|
||||||
|
alt.Tooltip(f"{metric}:Q", format=".4f"),
|
||||||
|
alt.Tooltip(f"{hue}:Q", format=".2e"),
|
||||||
|
f"{col}:N",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.properties(width=200, height=200)
|
||||||
|
.facet(facet=alt.Facet(f"{col}:N", title=col.replace("_binned", ""), sort=bin_order), columns=5)
|
||||||
|
)
|
||||||
|
|
||||||
|
return chart
|
||||||
|
|
||||||
|
|
||||||
def load_and_prepare_results(file_path: Path) -> pd.DataFrame:
|
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
|
||||||
"""Load results file and prepare binned columns."""
|
"""Load results file and prepare binned columns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the results parquet file.
|
||||||
|
k_bin_width: Width of bins for initial_K parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with added binned columns.
|
||||||
|
|
||||||
|
"""
|
||||||
results = pd.read_parquet(file_path)
|
results = pd.read_parquet(file_path)
|
||||||
|
|
||||||
# Bin the initial_K into 40er bins
|
# Automatically determine bin width for initial_K based on data range
|
||||||
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
|
k_min = results["initial_K"].min()
|
||||||
|
k_max = results["initial_K"].max()
|
||||||
|
# Use configurable bin width, adapted to actual data range
|
||||||
|
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
||||||
|
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
||||||
|
|
||||||
|
# Automatically create logarithmic bins for epsilon parameters based on data range
|
||||||
|
# Use 10 bins spanning the actual data range
|
||||||
|
eps_cl_min = np.log10(results["eps_cl"].min())
|
||||||
|
eps_cl_max = np.log10(results["eps_cl"].max())
|
||||||
|
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
|
||||||
|
|
||||||
|
eps_e_min = np.log10(results["eps_e"].min())
|
||||||
|
eps_e_max = np.log10(results["eps_e"].max())
|
||||||
|
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
|
||||||
|
|
||||||
# Bin the eps_cl and eps_e into logarithmic bins
|
|
||||||
eps_cl_bins = np.logspace(-3, 7, num=10)
|
|
||||||
eps_e_bins = np.logspace(-3, 7, num=10)
|
|
||||||
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
||||||
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
    """Open a stored model state.

    Args:
        file_path (Path): Location of the NetCDF file holding the model state.

    Returns:
        xr.Dataset: The dataset opened from ``file_path`` with the h5netcdf engine.

    """
    model_state = xr.open_dataset(file_path, engine="h5netcdf")
    return model_state
|
||||||
|
|
||||||
|
|
||||||
|
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
    """Pull the embedding feature weights out of the model state.

    A feature counts as an embedding feature when its name consists of exactly
    three underscore-separated parts and the middle part is ``A`` followed only
    by digits. The three parts are used as the ``agg``, ``band`` and ``year``
    coordinates.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The ``feature_weights`` of the embedding features with the
            single ``feature`` dimension unstacked into ('agg', 'band', 'year').
            None when no feature name matches the embedding pattern.

    """

    def _is_embedding_feature(feature: str) -> bool:
        pieces = feature.split("_")
        if len(pieces) != 3:
            return False
        band = pieces[1]
        return band.startswith("A") and band[1:].isdigit()

    selected = [name for name in model_state.feature.to_numpy() if _is_embedding_feature(name)]
    if not selected:
        return None

    # Each selected name is known to have exactly three parts: agg, band, year.
    name_parts = [name.split("_") for name in selected]
    weights = model_state.sel(feature=selected)["feature_weights"]
    weights = weights.assign_coords(
        agg=("feature", [parts[0] for parts in name_parts]),
        band=("feature", [parts[1] for parts in name_parts]),
        year=("feature", [parts[2] for parts in name_parts]),
    )
    # Replace the flat feature axis by a MultiIndex and expand it into separate dimensions.
    return weights.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
|
||||||
|
|
||||||
|
|
||||||
|
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
    """Pull the ERA5 feature weights out of the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The ``feature_weights`` of the ERA5 features with the single
            ``feature`` dimension unstacked into ('variable', 'time').
            None when no feature name is recognised as an ERA5 feature.

    """
    season_tokens = ("spring", "winter", "OND", "JFM", "AMJ", "JAS")

    def _is_seasonal(feature: str) -> bool:
        return any(token in feature for token in season_tokens)

    def _is_era5_feature(feature: str) -> bool:
        # Any name containing a season token (spring/winter or a three-month code) matches directly.
        if _is_seasonal(feature):
            return True
        pieces = feature.split("_")
        if len(pieces) == 3:
            # Three-part names whose middle part starts with "A" are excluded;
            # otherwise the last part has to be numeric.
            if pieces[1].startswith("A"):
                return False
            return pieces[2].isdigit()
        return pieces[-1].isdigit()

    def _split_var_and_time(feature: str) -> tuple[str, str]:
        # Seasonal names carry a two-part time suffix, all others a one-part suffix.
        n_time_parts = 2 if _is_seasonal(feature) else 1
        pieces = feature.rsplit("_", n_time_parts)
        return pieces[0], "_".join(pieces[-n_time_parts:])

    selected = [name for name in model_state.feature.to_numpy() if _is_era5_feature(name)]
    if not selected:
        return None

    var_time_pairs = [_split_var_and_time(name) for name in selected]
    weights = model_state.sel(feature=selected)["feature_weights"]
    weights = weights.assign_coords(
        variable=("feature", [var for var, _ in var_time_pairs]),
        time=("feature", [time for _, time in var_time_pairs]),
    )
    # Replace the flat feature axis by a MultiIndex and expand it into (variable, time).
    return weights.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Extract common features, e.g. area or water content
|
||||||
|
|
||||||
|
|
||||||
def get_available_result_files() -> list[Path]:
|
def get_available_result_files() -> list[Path]:
|
||||||
"""Get all available result files from RESULTS_DIR."""
|
"""Get all available result files from RESULTS_DIR."""
|
||||||
if not RESULTS_DIR.exists():
|
if not RESULTS_DIR.exists():
|
||||||
|
|
@ -119,14 +272,308 @@ def get_available_result_files() -> list[Path]:
|
||||||
|
|
||||||
result_files = []
|
result_files = []
|
||||||
for search_dir in RESULTS_DIR.iterdir():
|
for search_dir in RESULTS_DIR.iterdir():
|
||||||
if search_dir.is_dir():
|
if not search_dir.is_dir():
|
||||||
result_file = search_dir / "search_results.parquet"
|
continue
|
||||||
if result_file.exists():
|
|
||||||
result_files.append(result_file)
|
result_file = search_dir / "search_results.parquet"
|
||||||
|
state_file = search_dir / "best_estimator_state.nc"
|
||||||
|
if result_file.exists() and state_file.exists():
|
||||||
|
result_files.append(search_dir)
|
||||||
|
|
||||||
return sorted(result_files, reverse=True) # Most recent first
|
return sorted(result_files, reverse=True) # Most recent first
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_results_dir_name(results_dir: Path) -> str:
|
||||||
|
gridname, date = results_dir.name.split("_random_search_cv")
|
||||||
|
gridname = gridname.lstrip("permafrost_")
|
||||||
|
date = datetime.strptime(date, "%Y%m%d-%H%M%S")
|
||||||
|
date = date.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
return f"{gridname} ({date})"
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
    """Build a horizontal bar chart of the features with the largest absolute weight.

    Args:
        model_state: The xarray Dataset containing the model state; its
            ``feature_weights`` variable is read.
        top_n: Number of top features to display.

    Returns:
        Altair bar chart with one bar per selected feature, showing the signed
        weight and coloured by its sign.

    """
    weights = model_state["feature_weights"].to_pandas()

    # Rank by magnitude, but keep the signed value for the bar length.
    ranked_abs = weights.abs().nlargest(top_n).sort_values(ascending=True)
    selected = ranked_abs.index
    plot_data = pd.DataFrame(
        {
            "feature": selected,
            "weight": weights.loc[selected].to_numpy(),
            "abs_weight": ranked_abs.to_numpy(),
        }
    )

    # Positive weights are drawn in steelblue, all others in coral.
    sign_color = alt.condition(alt.datum.weight > 0, alt.value("steelblue"), alt.value("coral"))
    hover = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
        alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
    ]

    bars = alt.Chart(plot_data).mark_bar()
    bars = bars.encode(
        y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
        x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
        color=sign_color,
        tooltip=hover,
    )
    return bars.properties(width=600, height=400, title=f"Top {top_n} Most Important Features")
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_embedding_heatmap(embedding_array: xr.DataArray):
    """Build a faceted band-by-year heatmap of the embedding feature weights.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Altair chart with one heatmap facet per aggregation.

    """
    # Long-format table: one row per (agg, band, year) combination.
    long_table = embedding_array.to_dataframe(name="weight").reset_index()

    # Diverging colour scale centred on zero.
    weight_color = alt.Color(
        "weight:Q",
        scale=alt.Scale(scheme="redblue", domainMid=0),
        title="Weight",
    )
    hover = [
        alt.Tooltip("agg:N", title="Aggregation"),
        alt.Tooltip("band:N", title="Band"),
        alt.Tooltip("year:O", title="Year"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
    ]

    cells = alt.Chart(long_table).mark_rect()
    cells = cells.encode(
        x=alt.X("year:O", title="Year"),
        y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
        color=weight_color,
        tooltip=hover,
    )
    cells = cells.properties(width=200, height=200)
    return cells.facet(facet=alt.Facet("agg:N", title="Aggregation"), columns=11)
|
||||||
|
|
||||||
|
|
||||||
|
def _embedding_summary_bar(mean_abs: pd.Series, label: str, scheme: str, dim_type: str = "N"):
    """Build one sorted horizontal bar chart of mean absolute weights for a single dimension."""
    data = pd.DataFrame({"dimension": mean_abs.index, "mean_abs_weight": mean_abs.to_numpy()})
    data = data.sort_values("mean_abs_weight", ascending=True)
    return (
        alt.Chart(data)
        .mark_bar()
        .encode(
            y=alt.Y(f"dimension:{dim_type}", title=label, sort="-x"),
            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
            color=alt.Color("mean_abs_weight:Q", scale=alt.Scale(scheme=scheme), legend=None),
            tooltip=[
                alt.Tooltip(f"dimension:{dim_type}", title=label),
                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
            ],
        )
        .properties(width=250, height=200, title=f"By {label}")
    )


def _plot_embedding_aggregation_summary(embedding_array: xr.DataArray):
    """Create bar charts summarizing embedding weights by aggregation, band, and year.

    For each dimension the mean of the absolute weights over the other two
    dimensions is shown.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Tuple of three Altair charts (by_agg, by_band, by_year).

    """
    # Take magnitudes before averaging: averaging the signed weights first would
    # let positive and negative weights cancel, which is not what the
    # "Mean Absolute Weight" axis promises.
    magnitudes = abs(embedding_array)

    by_agg = magnitudes.mean(dim=["band", "year"]).to_pandas()
    by_band = magnitudes.mean(dim=["agg", "year"]).to_pandas()
    by_year = magnitudes.mean(dim=["agg", "band"]).to_pandas()

    chart_agg = _embedding_summary_bar(by_agg, "Aggregation", "blues")
    chart_band = _embedding_summary_bar(by_band, "Band", "greens")
    # Years are encoded as an ordinal field, the other two as nominal.
    chart_year = _embedding_summary_bar(by_year, "Year", "oranges", dim_type="O")

    return chart_agg, chart_band, chart_year
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_era5_heatmap(era5_array: xr.DataArray):
    """Build a variable-by-time heatmap of the ERA5 feature weights.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Altair chart showing the heatmap.

    """
    # Long-format table: one row per (variable, time) combination.
    long_table = era5_array.to_dataframe(name="weight").reset_index()

    # Diverging colour scale centred on zero.
    weight_color = alt.Color(
        "weight:Q",
        scale=alt.Scale(scheme="redblue", domainMid=0),
        title="Weight",
    )
    hover = [
        alt.Tooltip("variable:N", title="Variable"),
        alt.Tooltip("time:N", title="Time"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
    ]

    cells = alt.Chart(long_table).mark_rect()
    cells = cells.encode(
        x=alt.X("time:N", title="Time", sort=None),
        y=alt.Y("variable:N", title="Variable", sort="-color"),
        color=weight_color,
        tooltip=hover,
    )
    return cells.properties(height=400, title="ERA5 Feature Weights Heatmap")
|
||||||
|
|
||||||
|
|
||||||
|
def _era5_summary_bar(mean_abs: pd.Series, label: str, scheme: str, label_limit: int):
    """Build one sorted horizontal bar chart of mean absolute weights for a single dimension."""
    data = pd.DataFrame({"dimension": mean_abs.index, "mean_abs_weight": mean_abs.to_numpy()})
    data = data.sort_values("mean_abs_weight", ascending=True)
    return (
        alt.Chart(data)
        .mark_bar()
        .encode(
            y=alt.Y("dimension:N", title=label, sort="-x", axis=alt.Axis(labelLimit=label_limit)),
            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
            color=alt.Color("mean_abs_weight:Q", scale=alt.Scale(scheme=scheme), legend=None),
            tooltip=[
                alt.Tooltip("dimension:N", title=label),
                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
            ],
        )
        .properties(width=400, height=300, title=f"By {label}")
    )


def _plot_era5_summary(era5_array: xr.DataArray):
    """Create bar charts summarizing ERA5 weights by variable and time.

    For each dimension the mean of the absolute weights over the other
    dimension is shown.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Tuple of two Altair charts (by_variable, by_time).

    """
    # Take magnitudes before averaging: averaging the signed weights first would
    # let positive and negative weights cancel, which is not what the
    # "Mean Absolute Weight" axis promises.
    magnitudes = abs(era5_array)

    by_variable = magnitudes.mean(dim="time").to_pandas()
    by_time = magnitudes.mean(dim="variable").to_pandas()

    chart_variable = _era5_summary_bar(by_variable, "Variable", "purples", label_limit=300)
    chart_time = _era5_summary_bar(by_time, "Time", "teals", label_limit=200)

    return chart_variable, chart_time
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run Streamlit dashboard application."""
|
"""Run Streamlit dashboard application."""
|
||||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||||
|
|
@ -138,23 +585,30 @@ def main():
|
||||||
st.sidebar.header("Configuration")
|
st.sidebar.header("Configuration")
|
||||||
|
|
||||||
# Get available result files
|
# Get available result files
|
||||||
result_files = get_available_result_files()
|
result_dirs = get_available_result_files()
|
||||||
|
|
||||||
if not result_files:
|
if not result_dirs:
|
||||||
st.error(f"No result files found in {RESULTS_DIR}")
|
st.error(f"No result files found in {RESULTS_DIR}")
|
||||||
st.info("Please run a random CV search first to generate results.")
|
st.info("Please run a random CV search first to generate results.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# File selection
|
# Directory selection
|
||||||
file_options = {str(f.parent.name): f for f in result_files}
|
dir_options = {_parse_results_dir_name(f): f for f in result_dirs}
|
||||||
selected_file_name = st.sidebar.selectbox(
|
selected_dir_name = st.sidebar.selectbox(
|
||||||
"Select Result File", options=list(file_options.keys()), help="Choose a search result file to visualize"
|
"Select Result Directory",
|
||||||
|
options=list(dir_options.keys()),
|
||||||
|
help="Choose a search result directory to visualize",
|
||||||
)
|
)
|
||||||
selected_file = file_options[selected_file_name]
|
results_dir = dir_options[selected_dir_name]
|
||||||
|
|
||||||
# Load and prepare data
|
# Load and prepare data with default bin width (will be reloaded with custom width later)
|
||||||
with st.spinner("Loading results..."):
|
with st.spinner("Loading results..."):
|
||||||
results = load_and_prepare_results(selected_file)
|
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40)
|
||||||
|
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
|
||||||
|
n_features = model_state.sizes["feature"]
|
||||||
|
model_state["feature_weights"] *= n_features
|
||||||
|
embedding_feature_array = extract_embedding_features(model_state)
|
||||||
|
era5_feature_array = extract_era5_features(model_state)
|
||||||
|
|
||||||
st.sidebar.success(f"Loaded {len(results)} results")
|
st.sidebar.success(f"Loaded {len(results)} results")
|
||||||
|
|
||||||
|
|
@ -164,11 +618,6 @@ def main():
|
||||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Percentile normalization option
|
|
||||||
use_percentile = st.sidebar.checkbox(
|
|
||||||
"Use Percentile Normalization", value=True, help="Apply percentile-based color normalization to plots"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Display some basic statistics
|
# Display some basic statistics
|
||||||
st.header("Dataset Overview")
|
st.header("Dataset Overview")
|
||||||
col1, col2, col3 = st.columns(3)
|
col1, col2, col3 = st.columns(3)
|
||||||
|
|
@ -186,50 +635,235 @@ def main():
|
||||||
with st.expander("Best Parameters"):
|
with st.expander("Best Parameters"):
|
||||||
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
|
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
|
||||||
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
|
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
|
||||||
st.dataframe(best_params.to_frame().T, use_container_width=True)
|
st.dataframe(best_params.to_frame().T, width="content")
|
||||||
|
|
||||||
# Main plots
|
# Create tabs for different visualizations
|
||||||
st.header(f"Visualization for {selected_metric.capitalize()}")
|
tab1, tab2 = st.tabs(["Search Results", "Model State"])
|
||||||
|
|
||||||
# K-binned plots
|
with tab1:
|
||||||
st.subheader("K-Binned Parameter Space (Mean)")
|
# Main plots
|
||||||
with st.spinner("Generating mean plot..."):
|
st.header(f"Visualization for {selected_metric.capitalize()}")
|
||||||
if use_percentile:
|
|
||||||
fig1 = _plot_k_binned(results, f"mean_test_{selected_metric}", vmin_percentile=50)
|
# K-binned plot configuration
|
||||||
|
col_toggle, col_slider = st.columns([1, 1])
|
||||||
|
|
||||||
|
with col_toggle:
|
||||||
|
# Percentile normalization toggle for K-binned plots
|
||||||
|
use_percentile = st.toggle(
|
||||||
|
"Use Percentile Normalization",
|
||||||
|
value=True,
|
||||||
|
help="Apply percentile-based color normalization to K-binned parameter space plots",
|
||||||
|
)
|
||||||
|
|
||||||
|
with col_slider:
|
||||||
|
# Bin width slider for K-binned plots
|
||||||
|
k_min = int(results["initial_K"].min())
|
||||||
|
k_max = int(results["initial_K"].max())
|
||||||
|
k_range = k_max - k_min
|
||||||
|
|
||||||
|
k_bin_width = st.slider(
|
||||||
|
"Initial K Bin Width",
|
||||||
|
min_value=10,
|
||||||
|
max_value=max(100, k_range // 2),
|
||||||
|
value=40,
|
||||||
|
step=10,
|
||||||
|
help=f"Width of bins for initial_K facets (range: {k_min}-{k_max})",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show estimated number of bins
|
||||||
|
estimated_bins = int(np.ceil(k_range / k_bin_width))
|
||||||
|
st.caption(f"Creating approximately {estimated_bins} bins for initial_K")
|
||||||
|
|
||||||
|
# Reload data if bin width changed from default
|
||||||
|
if k_bin_width != 40:
|
||||||
|
with st.spinner("Re-binning data..."):
|
||||||
|
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width)
|
||||||
|
|
||||||
|
# K-binned plots
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.subheader("K-Binned Parameter Space (Mean)")
|
||||||
|
with st.spinner("Generating mean plot..."):
|
||||||
|
if use_percentile:
|
||||||
|
chart1 = _plot_k_binned(results, f"mean_test_{selected_metric}", vmin_percentile=50)
|
||||||
|
else:
|
||||||
|
chart1 = _plot_k_binned(results, f"mean_test_{selected_metric}")
|
||||||
|
st.altair_chart(chart1, use_container_width=True)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.subheader("K-Binned Parameter Space (Std)")
|
||||||
|
with st.spinner("Generating std plot..."):
|
||||||
|
if use_percentile:
|
||||||
|
chart2 = _plot_k_binned(results, f"std_test_{selected_metric}", vmax_percentile=50)
|
||||||
|
else:
|
||||||
|
chart2 = _plot_k_binned(results, f"std_test_{selected_metric}")
|
||||||
|
st.altair_chart(chart2, use_container_width=True)
|
||||||
|
|
||||||
|
# Epsilon-binned plots
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.subheader("K vs eps_cl")
|
||||||
|
with st.spinner("Generating eps_cl plot..."):
|
||||||
|
chart3 = _plot_eps_binned(results, "eps_cl", f"mean_test_{selected_metric}")
|
||||||
|
st.altair_chart(chart3, use_container_width=True)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.subheader("K vs eps_e")
|
||||||
|
with st.spinner("Generating eps_e plot..."):
|
||||||
|
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
|
||||||
|
st.altair_chart(chart4, use_container_width=True)
|
||||||
|
|
||||||
|
# Optional: Raw data table
|
||||||
|
with st.expander("View Raw Results Data"):
|
||||||
|
st.dataframe(results, width="stretch")
|
||||||
|
|
||||||
|
with tab2:
|
||||||
|
# Model state visualization
|
||||||
|
st.header("Best Estimator Model State")
|
||||||
|
|
||||||
|
# Show basic model state info
|
||||||
|
with st.expander("Model State Information"):
|
||||||
|
st.write(f"**Variables:** {list(model_state.data_vars)}")
|
||||||
|
st.write(f"**Dimensions:** {dict(model_state.sizes)}")
|
||||||
|
st.write(f"**Coordinates:** {list(model_state.coords)}")
|
||||||
|
|
||||||
|
# Show statistics
|
||||||
|
st.write("**Feature Weight Statistics:**")
|
||||||
|
feature_weights = model_state["feature_weights"].to_pandas()
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
with col1:
|
||||||
|
st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
|
||||||
|
with col2:
|
||||||
|
st.metric("Max Weight", f"{feature_weights.max():.4f}")
|
||||||
|
with col3:
|
||||||
|
st.metric("Total Features", len(feature_weights))
|
||||||
|
|
||||||
|
# Feature importance plot
|
||||||
|
st.subheader("Feature Importance")
|
||||||
|
st.markdown("The most important features based on learned feature weights from the best estimator.")
|
||||||
|
|
||||||
|
# Slider to control number of features to display
|
||||||
|
top_n = st.slider(
|
||||||
|
"Number of top features to display",
|
||||||
|
min_value=5,
|
||||||
|
max_value=50,
|
||||||
|
value=10,
|
||||||
|
step=5,
|
||||||
|
help="Select how many of the most important features to visualize",
|
||||||
|
)
|
||||||
|
|
||||||
|
with st.spinner("Generating feature importance plot..."):
|
||||||
|
feature_chart = _plot_top_features(model_state, top_n=top_n)
|
||||||
|
st.altair_chart(feature_chart, use_container_width=True)
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Interpretation:**
|
||||||
|
- **Magnitude**: Larger absolute values indicate more important features
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding features analysis (if present)
|
||||||
|
if embedding_feature_array is not None:
|
||||||
|
st.subheader("Embedding Feature Analysis")
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Analysis of embedding features showing which aggregations, bands, and years
|
||||||
|
are most important for the model predictions.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Summary bar charts
|
||||||
|
st.markdown("### Importance by Dimension")
|
||||||
|
with st.spinner("Generating dimension summaries..."):
|
||||||
|
chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array)
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
with col1:
|
||||||
|
st.altair_chart(chart_agg, use_container_width=True)
|
||||||
|
with col2:
|
||||||
|
st.altair_chart(chart_band, use_container_width=True)
|
||||||
|
with col3:
|
||||||
|
st.altair_chart(chart_year, use_container_width=True)
|
||||||
|
|
||||||
|
# Detailed heatmap
|
||||||
|
st.markdown("### Detailed Heatmap by Aggregation")
|
||||||
|
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
|
||||||
|
with st.spinner("Generating heatmap..."):
|
||||||
|
heatmap_chart = _plot_embedding_heatmap(embedding_feature_array)
|
||||||
|
st.altair_chart(heatmap_chart, use_container_width=True)
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
with st.expander("Embedding Feature Statistics"):
|
||||||
|
st.write("**Overall Statistics:**")
|
||||||
|
n_emb_features = embedding_feature_array.size
|
||||||
|
mean_weight = float(embedding_feature_array.mean().values)
|
||||||
|
max_weight = float(embedding_feature_array.max().values)
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
with col1:
|
||||||
|
st.metric("Total Embedding Features", n_emb_features)
|
||||||
|
with col2:
|
||||||
|
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||||
|
with col3:
|
||||||
|
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||||
|
|
||||||
|
# Show top embedding features
|
||||||
|
st.write("**Top 10 Embedding Features:**")
|
||||||
|
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
|
||||||
|
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
|
||||||
|
st.dataframe(top_emb, width="stretch")
|
||||||
else:
|
else:
|
||||||
fig1 = _plot_k_binned(results, f"mean_test_{selected_metric}")
|
st.info("No embedding features found in this model.")
|
||||||
st.pyplot(fig1.figure)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
st.subheader("K-Binned Parameter Space (Std)")
|
# ERA5 features analysis (if present)
|
||||||
with st.spinner("Generating std plot..."):
|
if era5_feature_array is not None:
|
||||||
if use_percentile:
|
st.subheader("ERA5 Feature Analysis")
|
||||||
fig2 = _plot_k_binned(results, f"std_test_{selected_metric}", vmax_percentile=50)
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Analysis of ERA5 climate features showing which variables and time periods
|
||||||
|
are most important for the model predictions.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Summary bar charts
|
||||||
|
st.markdown("### Importance by Dimension")
|
||||||
|
with st.spinner("Generating ERA5 dimension summaries..."):
|
||||||
|
chart_variable, chart_time = _plot_era5_summary(era5_feature_array)
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
st.altair_chart(chart_variable, use_container_width=True)
|
||||||
|
with col2:
|
||||||
|
st.altair_chart(chart_time, use_container_width=True)
|
||||||
|
|
||||||
|
# Detailed heatmap
|
||||||
|
st.markdown("### Detailed Heatmap")
|
||||||
|
st.markdown("Shows the weight of each variable-time combination.")
|
||||||
|
with st.spinner("Generating ERA5 heatmap..."):
|
||||||
|
era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array)
|
||||||
|
st.altair_chart(era5_heatmap_chart, use_container_width=True)
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
with st.expander("ERA5 Feature Statistics"):
|
||||||
|
st.write("**Overall Statistics:**")
|
||||||
|
n_era5_features = era5_feature_array.size
|
||||||
|
mean_weight = float(era5_feature_array.mean().values)
|
||||||
|
max_weight = float(era5_feature_array.max().values)
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
with col1:
|
||||||
|
st.metric("Total ERA5 Features", n_era5_features)
|
||||||
|
with col2:
|
||||||
|
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||||
|
with col3:
|
||||||
|
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||||
|
|
||||||
|
# Show top ERA5 features
|
||||||
|
st.write("**Top 10 ERA5 Features:**")
|
||||||
|
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
|
||||||
|
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
|
||||||
|
st.dataframe(top_era5, width="stretch")
|
||||||
else:
|
else:
|
||||||
fig2 = _plot_k_binned(results, f"std_test_{selected_metric}")
|
st.info("No ERA5 features found in this model.")
|
||||||
st.pyplot(fig2.figure)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# Epsilon-binned plots
|
|
||||||
col1, col2 = st.columns(2)
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
st.subheader("K vs eps_cl")
|
|
||||||
with st.spinner("Generating eps_cl plot..."):
|
|
||||||
fig3 = _plot_eps_binned(results, "eps_cl", f"mean_test_{selected_metric}")
|
|
||||||
st.pyplot(fig3.figure)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
st.subheader("K vs eps_e")
|
|
||||||
with st.spinner("Generating eps_e plot..."):
|
|
||||||
fig4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
|
|
||||||
st.pyplot(fig4.figure)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# Optional: Raw data table
|
|
||||||
with st.expander("View Raw Results Data"):
|
|
||||||
st.dataframe(results, use_container_width=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue