Run main experiments
This commit is contained in:
parent
3ce6b6e867
commit
f9df8e9fe6
43 changed files with 4112 additions and 4022 deletions
197
pixi.lock
generated
197
pixi.lock
generated
|
|
@ -243,6 +243,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
|
||||
|
|
@ -257,14 +258,14 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/31/4a/72dc383d1a0d14f1d453e334e3461e229762edb1bf3f75b3ab977e9386ed/arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1b/df/2a5a1306dc1699b51b02c1c38c55f3564a8c4f84087c23c61e7e7ae37dfa/arro3_io-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/74/51/59effa402d4ce8813e42eb62416059d42dd07826b0e7aa2db057c336972d/astropy_iers_data-0.2026.2.2.0.48.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/4e/67/6af8b422c04dec79c908cf60fdcd4725c3c112b2a058087c4ff58284a142/astropy_iers_data-0.2026.2.9.0.50.33-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2d/1b/37d8a28965907d23eeba8bce56272932ee01176d192cefdf19a4a0b53c00/autogluon_common-1.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/96/de/4bffa0f6f3257e73a22402019d19fbe34dfedc2865896f97ad57935cf7dd/autogluon_core-1.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/f3/c8/46eb69e371da89337419d3c754140f3ddae3c85a81b061ba3f275f442475/autogluon_features-1.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/48/7c/50547d2940e98c8a15b8c92cd4953814385b95f5fc1dec806fa240389417/autogluon_tabular-1.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/db/88/aaea2ad269ce70b446660371286272c1f6ba66541a7f6f635baf8b0db726/azure_core-1.38.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl
|
||||
|
|
@ -298,10 +299,10 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/0c/00/3106b1854b45bd0474ced037dfe6b73b90fe68a68968cef47c23de3d43d2/confection-0.1.5-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/4b/32/e0f13a1c5b0f8572d0ec6ae2f6c677b7991fafd95da523159c19eff0696a/contourpy-1.3.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2b/08/f83e2e0814248b844265802d081f2fac2f1cbe6cd258e72ba14ff006823a/cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/fa/25/0be9314cd72fe2ee2ef89ceb1f438bc156428a12177d684040456eee4a56/cupy_xarray-0.1.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1c/7c/996760c30f1302704af57c66ff2d723f7d656d0d0b93563b5528a51484bb/cyclopts-4.5.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2b/03/f906829bcfcbb945f19d6a64240ffb66a31d69ca5533e95882f0efc9c13c/cyclopts-4.5.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/36/36/bc980b9a14409f3356309c45a8d88d58797d02002a9d794dd6c84e809d3a/cymem-2.0.13-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e5/23/d39ccc4ed76222db31530b0a7d38876fdb7673e23f838e8d8f0ed4651a4f/dask-2026.1.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/28/0e/b11ad5fd77e3dd0baad9cac3184315be7654ae401e3b0b0c324503f23d96/datashader-0.18.2-py3-none-any.whl
|
||||
|
|
@ -314,7 +315,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/53/32/256df3dbaa198c58539ad94f9a41e98c2c8ff23f126b8f5f52c7dcd0a738/duckdb-1.4.4-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a7/39/ce46ee84779ef19d88fd028fc786a6dcc68b73ace33c31997aeda0dfecdc/earthengine_api-1.7.12-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c2/0f/875b6df73f884062f3bd7d62a2fb9bfc1d07d1c93a611e999401c5b10ca0/earthengine_api-1.7.13-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/04/40/2ccf4c87a5f9c8198fe71600d5f307f5dada201c091af8774a9c1e360865/ecmwf_datastores_client-0.4.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/65/54/5e3b0e41799e17e5eff1547fda4aab53878c0adb4243de6b95f8ddef899e/ee_extra-2025.7.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/06/98/3e22f4386f6c1957f5994c9aa9cedd8a442bb75766bd0b2e2c1c92854af9/eemont-2025.7.1-py3-none-any.whl
|
||||
|
|
@ -326,12 +327,14 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7e/31/d229f6cdb9cbe03020499d69c4b431b705aa19a55aa0fe698c98022b2fef/faiss_cpu-1.12.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c7/7d/74dd43d58f37584b32f0d781c8dbea9a286ee73e90393394e70569d4f254/fastai-2.8.6-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/23/03/2fe18e3d718b5a36d6c548df3e7662a4c433efea4d28662063d259248a1d/fastcore-1.12.11-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/59/f3/f71552b94a39509b62e72c4a26b6e4440bb9ce6decacf90af2916829e69e/fastcluster-1.3.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ea/d6/bb13c44b5863c0be7a27ef02982eca88f50d717549df1979e85942292239/fastcore-1.12.12-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/47/60/ed35253a05a70b63e4f52df1daa39a6a464a3e22b0bd060b77f63e2e2b6a/fastdownload-0.0.7-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/fe/a7/af33584fa6d17b911cfaba460efd3409cb5dd47083c181a4fdfec4bef840/fastlite-0.2.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/79/45/4aa502bbda9b63c792463c3466a2c5ef3c0830935f81906043f66b2b6c74/fastprogress-1.1.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/47/3d/4b85b47a7e70d5c7cc0cf7d7b2883646c9c0bd3ef54a33f23d5873aa910c/fasttransform-0.0.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/fa/97/3702c3be0e5ad3f46a75ccb9f30b6d20bd9432d9940a0c62dfa4869b4758/flox-0.11.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/92/54/dc5aec836660a37f11a8c66300bc2c18be254ef3a78ff08869ed1960c0fb/flox-0.11.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b5/a8/5f764f333204db0390362a4356d03a43626997f26818a0e9396f1b3bd8c9/folium-0.20.0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a3/4b/d67eedaed19def5967fade3297fed8161b25ba94699efc124b14fb68cdbc/fonttools-4.61.1-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/38/74/f94141b38a51a553efef7f510fc213894161ae49b88bffd037f8d2a7cb2f/frozendict-2.4.7-py3-none-any.whl
|
||||
|
|
@ -347,7 +350,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/77/b6/85c4d21067220b9a78cfb81f516f9725ea6befc1544ec9bd2c1acd97c324/google_api_core-2.29.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/04/44/3677ff27998214f2fa7957359da48da378a0ffff1bd0bdaba42e752bc13e/google_api_python_client-2.189.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/07/ad/223d5f4b0b987669ffeb3eadd7e9f85ece633aa7fd3246f1e2f6238e1e05/google_api_python_client-2.190.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/83/1d/d6466de3a5249d35e832a52834115ca9d1d0de6abc22065f049707516d47/google_auth-2.48.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/99/d5/3c97526c8796d3caf5f4b3bed2b05e8a7102326f00a334e7a438237f3b22/google_auth_httplib2-0.3.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl
|
||||
|
|
@ -400,7 +403,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/de/73/3d757cb3fc16f0f9794dd289bcd0c4a031d9cf54d8137d6b984b2d02edf3/lightning_utilities-0.15.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/db/bc/83e112abc66cd466c6b83f99118035867cecd41802f8d044638aa78a106e/locket-1.0.0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/87/c4/c15eb88220cc6211eb3756c858a76f6ac26b99e2433831d2d7022ad0ff72/lonboard-0.14.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a7/5e/19fb53bd69379498c47bc234ca4d2851cfbca333d6d6929b10251916da25/mapclassify-2.10.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/75/97/a471f1c3eb1fd6f6c24a31a5858f443891d5127e63a7788678d14e249aea/matplotlib-3.10.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
|
||||
|
|
@ -414,6 +417,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e2/63/58e2de2b5232cd294c64092688c422196e74f9fa8b3958bdf02d33df24b9/murmurhash-1.0.15-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/03/cc/7cb74758e6df95e0c4e1253f203b6dd7f348bf2f29cf89e9210a2416d535/narwhals-2.16.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7b/7a/a8d32501bb95ecff342004a674720164f95ad616f269450b3bc13dc88ae3/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ae/d3/ff8f1b9968aa4dcd1da1880322ed492314cc920998182e549b586c895a17/numbagg-0.9.4-py3-none-any.whl
|
||||
|
|
@ -436,13 +440,13 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/d1/c6/df1fe324248424f77b89371116dab5243db7f052c32cc9fe7442ad9c5f75/pandas_stubs-2.3.3.260113-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/11/b6/f8c7e1f5f716e16070cf35f90c24f95f397376bb810e65000b6bc55950cc/param-2.3.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c0/db/61efa0d08a99f897ef98256b03e563092d36cc38dc4ebe4a85020fe40b31/pbr-7.0.3-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/01/9a/632e58ec89a32738cabfd9ec418f0e9898a2b4719afc581f07c04a05e3c9/pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/71/24/538bff45bde96535d7d998c6fed1a751c75ac7c53c37c90dc2601b243893/pillow-12.1.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/de/f0/c81e05b613866b76d2d1066490adf1a3dbc4ee9d9c839961c3fc8a6997af/pip-26.0.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/da/10/1b0dcf51427326f70e50d98df21b18c228117a743a1fc515a42f8dc7d342/platformdirs-4.6.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/8a/67/f95b5460f127840310d2187f916cf0023b5875c0717fdf893f71e1325e87/plotly-6.5.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/42/88/71fa06eb487ed9d4fab0ad173300b7a58706385f98fb66b1ccdc3ec3d4dd/plum_dispatch-2.6.1-py3-none-any.whl
|
||||
|
|
@ -505,6 +509,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/ca/63/2c6daf59d86b1c30600bff679d039f57fd1932af82c43c0bde1cbc55e8d4/sentry_sdk-2.52.0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6b/6a/c006de5df0e0f4850aa94019df1f79bf6a5342fa851ca85e4728691fd0c4/shap-0.50.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/f2/a2/83fc37e2a58090e3d2ff79175a95493c664bcd0b653dd75cb9134645a4e5/shapely-2.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/63/81/9ef641ff4e12cbcca30e54e72fb0951a2ba195d0cda0ba4100e532d929db/slicer-0.0.8-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ab/6c/1d4db72c5dbbb9ea2fbc323a40986917cca84ca098f6fcf80624370979e7/smart_geocubes-0.1.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ad/95/bc978be7ea0babf2fb48a414b6afaad414c6a9e8b1eafc5b8a53c030381a/smart_open-7.5.0-py3-none-any.whl
|
||||
|
|
@ -544,7 +549,8 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/74/18/8dd4fe6df1fd66f3e83b4798eddb1d8482d9d9b105f25099b76703402ebb/ty-0.0.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c8/0a/4aca634faf693e33004796b6cee0ae2e1dba375a800c16ab8d3eff4bb800/typer_slim-0.21.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7a/ed/d6fca788b51d0d4640c4bc82d0e85bad4b49809bca36bf4af01b4dcb66a7/typer-0.23.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/07/3e/ba3a222c80ee070d9497ece3e1fe77253c142925dd4c90f04278aac0a9eb/typer_slim-0.23.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/d6/32/48209716f9715d77f1bce084ad74c5d3cfcf41fd78d0c7e7dbe4829cfa3a/ultraplot-1.72.0-py3-none-any.whl
|
||||
|
|
@ -732,6 +738,11 @@ packages:
|
|||
- jupyter-book ; extra == 'doc'
|
||||
- vl-convert-python ; extra == 'doc'
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl
|
||||
name: annotated-doc
|
||||
version: 0.0.4
|
||||
sha256: 571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
|
||||
name: annotated-types
|
||||
version: 0.7.0
|
||||
|
|
@ -920,10 +931,10 @@ packages:
|
|||
- astropy[dev] ; extra == 'dev-all'
|
||||
- astropy[test-all] ; extra == 'dev-all'
|
||||
requires_python: '>=3.11'
|
||||
- pypi: https://files.pythonhosted.org/packages/74/51/59effa402d4ce8813e42eb62416059d42dd07826b0e7aa2db057c336972d/astropy_iers_data-0.2026.2.2.0.48.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/4e/67/6af8b422c04dec79c908cf60fdcd4725c3c112b2a058087c4ff58284a142/astropy_iers_data-0.2026.2.9.0.50.33-py3-none-any.whl
|
||||
name: astropy-iers-data
|
||||
version: 0.2026.2.2.0.48.1
|
||||
sha256: 62aecb2faea740e0d714808b85512ebe4f29adbfe1e8d5e5481cfd66494d164f
|
||||
version: 0.2026.2.9.0.50.33
|
||||
sha256: ac01dede0240499b23c2b89fdc93093500336197c5c794e6a01173cfd78a7620
|
||||
requires_dist:
|
||||
- pytest ; extra == 'docs'
|
||||
- hypothesis ; extra == 'test'
|
||||
|
|
@ -1295,10 +1306,10 @@ packages:
|
|||
purls: []
|
||||
size: 3472674
|
||||
timestamp: 1765257107074
|
||||
- pypi: https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/db/88/aaea2ad269ce70b446660371286272c1f6ba66541a7f6f635baf8b0db726/azure_core-1.38.1-py3-none-any.whl
|
||||
name: azure-core
|
||||
version: 1.38.0
|
||||
sha256: ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335
|
||||
version: 1.38.1
|
||||
sha256: 69f08ee3d55136071b7100de5b198994fc1c5f89d2b91f2f43156d20fcf200a4
|
||||
requires_dist:
|
||||
- requests>=2.21.0
|
||||
- typing-extensions>=4.6.0
|
||||
|
|
@ -1871,17 +1882,17 @@ packages:
|
|||
purls: []
|
||||
size: 48369
|
||||
timestamp: 1765019689213
|
||||
- pypi: https://files.pythonhosted.org/packages/2b/08/f83e2e0814248b844265802d081f2fac2f1cbe6cd258e72ba14ff006823a/cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl
|
||||
name: cryptography
|
||||
version: 46.0.4
|
||||
sha256: 0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255
|
||||
version: 46.0.5
|
||||
sha256: 3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed
|
||||
requires_dist:
|
||||
- cffi>=1.14 ; python_full_version == '3.8.*' and platform_python_implementation != 'PyPy'
|
||||
- cffi>=2.0.0 ; python_full_version >= '3.9' and platform_python_implementation != 'PyPy'
|
||||
- typing-extensions>=4.13.2 ; python_full_version < '3.11'
|
||||
- bcrypt>=3.1.5 ; extra == 'ssh'
|
||||
- nox[uv]>=2024.4.15 ; extra == 'nox'
|
||||
- cryptography-vectors==46.0.4 ; extra == 'test'
|
||||
- cryptography-vectors==46.0.5 ; extra == 'test'
|
||||
- pytest>=7.4.0 ; extra == 'test'
|
||||
- pytest-benchmark>=4.0 ; extra == 'test'
|
||||
- pytest-cov>=2.10.1 ; extra == 'test'
|
||||
|
|
@ -2390,10 +2401,10 @@ packages:
|
|||
- pytest-cov ; extra == 'tests'
|
||||
- pytest-xdist ; extra == 'tests'
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/1c/7c/996760c30f1302704af57c66ff2d723f7d656d0d0b93563b5528a51484bb/cyclopts-4.5.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2b/03/f906829bcfcbb945f19d6a64240ffb66a31d69ca5533e95882f0efc9c13c/cyclopts-4.5.2-py3-none-any.whl
|
||||
name: cyclopts
|
||||
version: 4.5.1
|
||||
sha256: 0642c93601e554ca6b7b9abd81093847ea4448b2616280f2a0952416574e8c7a
|
||||
version: 4.5.2
|
||||
sha256: ee56ee23c2c81abc34b66b5aa8fd2698ca699740054e84e534449ec3eb7f944d
|
||||
requires_dist:
|
||||
- attrs>=23.1.0
|
||||
- docstring-parser>=0.15,<4.0
|
||||
|
|
@ -2576,10 +2587,10 @@ packages:
|
|||
- pyarrow ; extra == 'all'
|
||||
- adbc-driver-manager ; extra == 'all'
|
||||
requires_python: '>=3.9.0'
|
||||
- pypi: https://files.pythonhosted.org/packages/a7/39/ce46ee84779ef19d88fd028fc786a6dcc68b73ace33c31997aeda0dfecdc/earthengine_api-1.7.12-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c2/0f/875b6df73f884062f3bd7d62a2fb9bfc1d07d1c93a611e999401c5b10ca0/earthengine_api-1.7.13-py3-none-any.whl
|
||||
name: earthengine-api
|
||||
version: 1.7.12
|
||||
sha256: 39c24f65b97e88bfed325e55d7f9fa5c8e8a9f92280c5ed24e3ab36560f2b543
|
||||
version: 1.7.13
|
||||
sha256: 32a24b6003287f71afb24e2cee7718296f8d82778488b88c1e760279c7e47840
|
||||
requires_dist:
|
||||
- google-cloud-storage
|
||||
- google-api-python-client>=1.12.1
|
||||
|
|
@ -2648,7 +2659,7 @@ packages:
|
|||
- pypi: ./
|
||||
name: entropice
|
||||
version: 0.1.0
|
||||
sha256: 07232c2b09b1b8b691cc8ca7d25b3c0041f2324236a11491f9f07b7e6827973a
|
||||
sha256: d9313dad098d69cd67a908e0cb26d4506c8bf723d6a42d0b213d84a6bdb03e9e
|
||||
requires_dist:
|
||||
- aiohttp>=3.12.11
|
||||
- bokeh>=3.7.3
|
||||
|
|
@ -2715,6 +2726,8 @@ packages:
|
|||
- shap>=0.50.0,<0.51
|
||||
- h5py>=3.15.1,<4
|
||||
- pydantic>=2.12.5,<3
|
||||
- nbformat>=5.10.4,<6
|
||||
- fastcluster>=1.3.0,<2
|
||||
requires_python: '>=3.13,<3.14'
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9152653278559faff830ff984a66d30b8ae5657c
|
||||
name: entropy
|
||||
|
|
@ -2801,10 +2814,17 @@ packages:
|
|||
- accelerate>=0.21 ; extra == 'dev'
|
||||
- ipykernel ; extra == 'dev'
|
||||
requires_python: '>=3.10'
|
||||
- pypi: https://files.pythonhosted.org/packages/23/03/2fe18e3d718b5a36d6c548df3e7662a4c433efea4d28662063d259248a1d/fastcore-1.12.11-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/59/f3/f71552b94a39509b62e72c4a26b6e4440bb9ce6decacf90af2916829e69e/fastcluster-1.3.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
name: fastcluster
|
||||
version: 1.3.0
|
||||
sha256: 2dce31ace6f8e08c5400d6e19492fe09aba2b050f78a7aa6943ba2ae50dcd1b0
|
||||
requires_dist:
|
||||
- numpy>=2
|
||||
requires_python: '>=3'
|
||||
- pypi: https://files.pythonhosted.org/packages/ea/d6/bb13c44b5863c0be7a27ef02982eca88f50d717549df1979e85942292239/fastcore-1.12.12-py3-none-any.whl
|
||||
name: fastcore
|
||||
version: 1.12.11
|
||||
sha256: b6a0ce9f48509405109251d00ac0576cfe5cba0a2b1b495a4126283969efbad5
|
||||
version: 1.12.12
|
||||
sha256: bb1a3a3accd62a72bad56af974e0617af078316be5bb5dcc8763b8244c197fa8
|
||||
requires_dist:
|
||||
- numpy ; extra == 'dev'
|
||||
- nbdev>=0.2.39 ; extra == 'dev'
|
||||
|
|
@ -2826,6 +2846,19 @@ packages:
|
|||
- fastprogress
|
||||
- fastcore>=1.3.26
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl
|
||||
name: fastjsonschema
|
||||
version: 2.21.2
|
||||
sha256: 1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463
|
||||
requires_dist:
|
||||
- colorama ; extra == 'devel'
|
||||
- jsonschema ; extra == 'devel'
|
||||
- json-spec ; extra == 'devel'
|
||||
- pylint ; extra == 'devel'
|
||||
- pytest ; extra == 'devel'
|
||||
- pytest-benchmark ; extra == 'devel'
|
||||
- pytest-cache ; extra == 'devel'
|
||||
- validictory ; extra == 'devel'
|
||||
- pypi: https://files.pythonhosted.org/packages/fe/a7/af33584fa6d17b911cfaba460efd3409cb5dd47083c181a4fdfec4bef840/fastlite-0.2.4-py3-none-any.whl
|
||||
name: fastlite
|
||||
version: 0.2.4
|
||||
|
|
@ -2879,10 +2912,10 @@ packages:
|
|||
- pkg:pypi/filelock?source=hash-mapping
|
||||
size: 18609
|
||||
timestamp: 1765846639623
|
||||
- pypi: https://files.pythonhosted.org/packages/fa/97/3702c3be0e5ad3f46a75ccb9f30b6d20bd9432d9940a0c62dfa4869b4758/flox-0.11.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/92/54/dc5aec836660a37f11a8c66300bc2c18be254ef3a78ff08869ed1960c0fb/flox-0.11.1-py3-none-any.whl
|
||||
name: flox
|
||||
version: 0.11.0
|
||||
sha256: 61620abc0eec12a3d6f93fd08f17326435b17d256678a5380598d10b25012751
|
||||
version: 0.11.1
|
||||
sha256: 2c5da10771d139118eee7ca453b5a60c34f051cf1c06f2e5446728bc09fce2ec
|
||||
requires_dist:
|
||||
- pandas>=2.1
|
||||
- packaging>=21.3
|
||||
|
|
@ -3343,10 +3376,10 @@ packages:
|
|||
- grpcio-gcp>=0.2.2,<1.0.0 ; extra == 'grpcgcp'
|
||||
- grpcio-gcp>=0.2.2,<1.0.0 ; extra == 'grpcio-gcp'
|
||||
requires_python: '>=3.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/04/44/3677ff27998214f2fa7957359da48da378a0ffff1bd0bdaba42e752bc13e/google_api_python_client-2.189.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/07/ad/223d5f4b0b987669ffeb3eadd7e9f85ece633aa7fd3246f1e2f6238e1e05/google_api_python_client-2.190.0-py3-none-any.whl
|
||||
name: google-api-python-client
|
||||
version: 2.189.0
|
||||
sha256: a258c09660a49c6159173f8bbece171278e917e104a11f0640b34751b79c8a1a
|
||||
version: 2.190.0
|
||||
sha256: d9b5266758f96c39b8c21d9bbfeb4e58c14dbfba3c931f7c5a8d7fdcd292dd57
|
||||
requires_dist:
|
||||
- httplib2>=0.19.0,<1.0.0
|
||||
- google-auth>=1.32.0,!=2.24.0,!=2.25.0,<3.0.0
|
||||
|
|
@ -5787,10 +5820,10 @@ packages:
|
|||
- build==1.2.2 ; python_full_version >= '3.11' and extra == 'dev'
|
||||
- twine==6.0.1 ; python_full_version >= '3.11' and extra == 'dev'
|
||||
requires_python: '>=3.5,<4.0'
|
||||
- pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/87/c4/c15eb88220cc6211eb3756c858a76f6ac26b99e2433831d2d7022ad0ff72/lonboard-0.14.0-py3-none-any.whl
|
||||
name: lonboard
|
||||
version: 0.13.0
|
||||
sha256: 8acb17fdcbb34bd147a68aebd4b887996171e0eb9df7f4fc06e467cdfa32fb07
|
||||
version: 0.14.0
|
||||
sha256: 35f3218490e7bf07562575a872b45d44cf02026c3db9d0af486ad716dca02551
|
||||
requires_dist:
|
||||
- anywidget~=0.9.0
|
||||
- arro3-compute>=0.4.1
|
||||
|
|
@ -5805,12 +5838,14 @@ packages:
|
|||
- click>=8.1.7 ; extra == 'cli'
|
||||
- pyogrio>=0.8 ; extra == 'cli'
|
||||
- shapely>=2 ; extra == 'cli'
|
||||
- async-geotiff>=0.1.0 ; python_full_version >= '3.11' and extra == 'cog'
|
||||
- morecantile>=7.0 ; python_full_version >= '3.11' and extra == 'cog'
|
||||
- geopandas>=0.13 ; extra == 'geopandas'
|
||||
- pandas>=2 ; extra == 'geopandas'
|
||||
- pyarrow>=16 ; extra == 'geopandas'
|
||||
- shapely>=2 ; extra == 'geopandas'
|
||||
- movingpandas>=0.17 ; extra == 'movingpandas'
|
||||
requires_python: '>=3.10'
|
||||
requires_python: '>=3.11'
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.10.0-h5888daf_1.conda
|
||||
sha256: 47326f811392a5fd3055f0f773036c392d26fdb32e4d8e7a8197eed951489346
|
||||
md5: 9de5350a85c4a20c685259b889aa6393
|
||||
|
|
@ -6127,6 +6162,25 @@ packages:
|
|||
- sqlparse ; extra == 'sql'
|
||||
- sqlframe>=3.22.0,!=3.39.3 ; extra == 'sqlframe'
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl
|
||||
name: nbformat
|
||||
version: 5.10.4
|
||||
sha256: 3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b
|
||||
requires_dist:
|
||||
- fastjsonschema>=2.15
|
||||
- jsonschema>=2.6
|
||||
- jupyter-core>=4.12,!=5.0.*
|
||||
- traitlets>=5.1
|
||||
- myst-parser ; extra == 'docs'
|
||||
- pydata-sphinx-theme ; extra == 'docs'
|
||||
- sphinx ; extra == 'docs'
|
||||
- sphinxcontrib-github-alt ; extra == 'docs'
|
||||
- sphinxcontrib-spelling ; extra == 'docs'
|
||||
- pep440 ; extra == 'test'
|
||||
- pre-commit ; extra == 'test'
|
||||
- pytest ; extra == 'test'
|
||||
- testpath ; extra == 'test'
|
||||
requires_python: '>=3.8'
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/nccl-2.28.9.1-h4d09622_1.conda
|
||||
sha256: a132df4a0b4c36932cfd5e931b4c88e83991ad77de9adf13c206caefdaf3b8b0
|
||||
md5: af3e8d72000a10bd8159d7e28daf4bfc
|
||||
|
|
@ -6710,15 +6764,15 @@ packages:
|
|||
- nest-asyncio ; extra == 'tests-pypy'
|
||||
- numpy ; extra == 'tests-pypy'
|
||||
requires_python: '>=3.10'
|
||||
- pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl
|
||||
name: parso
|
||||
version: 0.8.5
|
||||
sha256: 646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887
|
||||
version: 0.8.6
|
||||
sha256: 2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff
|
||||
requires_dist:
|
||||
- pytest ; extra == 'testing'
|
||||
- docopt ; extra == 'testing'
|
||||
- flake8==5.0.4 ; extra == 'qa'
|
||||
- mypy==0.971 ; extra == 'qa'
|
||||
- zuban==0.5.1 ; extra == 'qa'
|
||||
- types-setuptools==67.2.0.1 ; extra == 'qa'
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl
|
||||
|
|
@ -6746,10 +6800,10 @@ packages:
|
|||
sha256: 7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523
|
||||
requires_dist:
|
||||
- ptyprocess>=0.5
|
||||
- pypi: https://files.pythonhosted.org/packages/01/9a/632e58ec89a32738cabfd9ec418f0e9898a2b4719afc581f07c04a05e3c9/pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/71/24/538bff45bde96535d7d998c6fed1a751c75ac7c53c37c90dc2601b243893/pillow-12.1.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
name: pillow
|
||||
version: 12.1.0
|
||||
sha256: 6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc
|
||||
version: 12.1.1
|
||||
sha256: 47b94983da0c642de92ced1702c5b6c292a84bd3a8e1d1702ff923f183594717
|
||||
requires_dist:
|
||||
- furo ; extra == 'docs'
|
||||
- olefile ; extra == 'docs'
|
||||
|
|
@ -6783,21 +6837,10 @@ packages:
|
|||
version: 26.0.1
|
||||
sha256: bdb1b08f4274833d62c1aa29e20907365a2ceb950410df15fc9521bad440122b
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/da/10/1b0dcf51427326f70e50d98df21b18c228117a743a1fc515a42f8dc7d342/platformdirs-4.6.0-py3-none-any.whl
|
||||
name: platformdirs
|
||||
version: 4.5.1
|
||||
sha256: d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31
|
||||
requires_dist:
|
||||
- furo>=2025.9.25 ; extra == 'docs'
|
||||
- proselint>=0.14 ; extra == 'docs'
|
||||
- sphinx-autodoc-typehints>=3.2 ; extra == 'docs'
|
||||
- sphinx>=8.2.3 ; extra == 'docs'
|
||||
- appdirs==1.4.4 ; extra == 'test'
|
||||
- covdefaults>=2.3 ; extra == 'test'
|
||||
- pytest-cov>=7 ; extra == 'test'
|
||||
- pytest-mock>=3.15.1 ; extra == 'test'
|
||||
- pytest>=8.4.2 ; extra == 'test'
|
||||
- mypy>=1.18.2 ; extra == 'type'
|
||||
version: 4.6.0
|
||||
sha256: dd7f808d828e1764a22ebff09e60f175ee3c41876606a6132a688d809c7c9c73
|
||||
requires_python: '>=3.10'
|
||||
- pypi: https://files.pythonhosted.org/packages/8a/67/f95b5460f127840310d2187f916cf0023b5875c0717fdf893f71e1325e87/plotly-6.5.2-py3-none-any.whl
|
||||
name: plotly
|
||||
|
|
@ -8464,6 +8507,11 @@ packages:
|
|||
- sphinx-book-theme ; extra == 'docs'
|
||||
- sphinx-remove-toctrees ; extra == 'docs'
|
||||
requires_python: '>=3.10'
|
||||
- pypi: https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl
|
||||
name: shellingham
|
||||
version: 1.5.4
|
||||
sha256: 7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686
|
||||
requires_python: '>=3.7'
|
||||
- conda: https://conda.anaconda.org/conda-forge/noarch/six-1.17.0-pyhe01879c_1.conda
|
||||
sha256: 458227f759d5e3fcec5d9b7acce54e10c9e1f4f4b7ec978f3bfd54ce4ee9853d
|
||||
md5: 3339e3b65d58accf4ca4fb8748ab16b3
|
||||
|
|
@ -9775,15 +9823,22 @@ packages:
|
|||
version: 0.0.11
|
||||
sha256: 25f88e8789072830348cb59b761d5ced70642ed5600673b4bf6a849af71eca8b
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/c8/0a/4aca634faf693e33004796b6cee0ae2e1dba375a800c16ab8d3eff4bb800/typer_slim-0.21.1-py3-none-any.whl
|
||||
name: typer-slim
|
||||
version: 0.21.1
|
||||
sha256: 6e6c31047f171ac93cc5a973c9e617dbc5ab2bddc4d0a3135dc161b4e2020e0d
|
||||
- pypi: https://files.pythonhosted.org/packages/7a/ed/d6fca788b51d0d4640c4bc82d0e85bad4b49809bca36bf4af01b4dcb66a7/typer-0.23.0-py3-none-any.whl
|
||||
name: typer
|
||||
version: 0.23.0
|
||||
sha256: 79f4bc262b6c37872091072a3cb7cb6d7d79ee98c0c658b4364bdcde3c42c913
|
||||
requires_dist:
|
||||
- click>=8.0.0
|
||||
- typing-extensions>=3.7.4.3
|
||||
- shellingham>=1.3.0 ; extra == 'standard'
|
||||
- rich>=10.11.0 ; extra == 'standard'
|
||||
- shellingham>=1.3.0
|
||||
- rich>=10.11.0
|
||||
- annotated-doc>=0.0.2
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://files.pythonhosted.org/packages/07/3e/ba3a222c80ee070d9497ece3e1fe77253c142925dd4c90f04278aac0a9eb/typer_slim-0.23.0-py3-none-any.whl
|
||||
name: typer-slim
|
||||
version: 0.23.0
|
||||
sha256: 1d693daf22d998a7b1edab8413cdcb8af07254154ce3956c1664dc11b01e2f8b
|
||||
requires_dist:
|
||||
- typer>=0.23.0
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl
|
||||
name: types-pytz
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ dependencies = [
|
|||
"autogluon-tabular[all,mitra,realmlp,interpret,fastai,tabm,tabpfn,tabdpt,tabpfnmix,tabicl,skew,imodels]>=1.5.0",
|
||||
"shap>=0.50.0,<0.51",
|
||||
"h5py>=3.15.1,<4",
|
||||
"pydantic>=2.12.5,<3",
|
||||
"pydantic>=2.12.5,<3", "nbformat>=5.10.4,<6", "fastcluster>=1.3.0,<2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,33 @@
|
|||
#! /bin/bash
|
||||
|
||||
# Check if running inside the pixi environment
|
||||
which darts >/dev/null 2>&1
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "This script must be run inside the pixi environment."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# pixi shell
|
||||
darts extract-darts-v2 --grid hex --level 3
|
||||
darts extract-darts-v2 --grid hex --level 4
|
||||
darts extract-darts-v2 --grid hex --level 5
|
||||
darts extract-darts-v2 --grid hex --level 6
|
||||
darts extract-darts-v2 --grid healpix --level 6
|
||||
darts extract-darts-v2 --grid healpix --level 7
|
||||
darts extract-darts-v2 --grid healpix --level 8
|
||||
darts extract-darts-v2 --grid healpix --level 9
|
||||
darts extract-darts-v2 --grid healpix --level 10
|
||||
|
||||
darts extract-darts-v2-aggregated --grid hex --level 3
|
||||
darts extract-darts-v2-aggregated --grid hex --level 4
|
||||
darts extract-darts-v2-aggregated --grid hex --level 5
|
||||
darts extract-darts-v2-aggregated --grid hex --level 6
|
||||
darts extract-darts-v2-aggregated --grid healpix --level 6
|
||||
darts extract-darts-v2-aggregated --grid healpix --level 7
|
||||
darts extract-darts-v2-aggregated --grid healpix --level 8
|
||||
darts extract-darts-v2-aggregated --grid healpix --level 9
|
||||
darts extract-darts-v2-aggregated --grid healpix --level 10
|
||||
exit 0
|
||||
darts extract-darts-v1 --grid hex --level 3
|
||||
darts extract-darts-v1 --grid hex --level 4
|
||||
darts extract-darts-v1 --grid hex --level 5
|
||||
|
|
@ -22,7 +49,6 @@ darts extract-darts-v1-aggregated --grid healpix --level 8
|
|||
darts extract-darts-v1-aggregated --grid healpix --level 9
|
||||
darts extract-darts-v1-aggregated --grid healpix --level 10
|
||||
|
||||
|
||||
darts extract-darts-mllabels --grid hex --level 3
|
||||
darts extract-darts-mllabels --grid hex --level 4
|
||||
darts extract-darts-mllabels --grid hex --level 5
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import streamlit as st
|
|||
from entropice.dashboard.views.dataset_page import render_dataset_page
|
||||
from entropice.dashboard.views.experiment_analysis_page import render_experiment_analysis_page
|
||||
from entropice.dashboard.views.inference_page import render_inference_page
|
||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.overview_page import render_overview_page
|
||||
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||
|
||||
|
|
@ -30,14 +29,13 @@ def main():
|
|||
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
experiment_analysis_page = st.Page(render_experiment_analysis_page, title="Experiment Analysis", icon="🔬")
|
||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||
|
||||
pg = st.navigation(
|
||||
{
|
||||
"Overview": [overview_page],
|
||||
"Data": [data_page],
|
||||
"Experiments": [training_analysis_page, experiment_analysis_page, model_state_page],
|
||||
"Experiments": [training_analysis_page, experiment_analysis_page],
|
||||
"Inference": [inference_page],
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def create_grid_level_comparison_plot(
|
|||
Args:
|
||||
results_df: DataFrame with experiment results including grid, level, model, and metrics
|
||||
metric: Metric to compare (e.g., 'f1', 'accuracy', 'r2')
|
||||
split: Data split to show ('train', 'test', or 'combined')
|
||||
split: Data split to show ('train', 'test', or 'complete')
|
||||
|
||||
Returns:
|
||||
Plotly figure showing performance across grid levels
|
||||
|
|
@ -82,7 +82,7 @@ def create_grid_level_comparison_plot(
|
|||
"autogluon": "star",
|
||||
}
|
||||
|
||||
# Add a combined column for hover information
|
||||
# Add an upper-cased model name column for hover information
|
||||
results_df["model_display"] = results_df["model"].str.upper()
|
||||
|
||||
# Create box plot without individual points first
|
||||
|
|
|
|||
|
|
@ -1,493 +0,0 @@
|
|||
"""Plotting functions for inference result visualizations."""
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import pydeck as pdk
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.utils.types import Task
|
||||
|
||||
# Canonical orderings imported from the ML pipeline
|
||||
# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"}
|
||||
# Count/Density labels are defined in the bin_values function
|
||||
BINARY_LABELS = ["No RTS", "RTS"]
|
||||
COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"]
|
||||
DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"]
|
||||
|
||||
CLASS_ORDERINGS: dict[Task | str, list[str]] = {
|
||||
"binary": BINARY_LABELS,
|
||||
"count_regimes": COUNT_LABELS,
|
||||
"density_regimes": DENSITY_LABELS,
|
||||
# Legacy aliases (deprecated)
|
||||
"count": COUNT_LABELS,
|
||||
"density": DENSITY_LABELS,
|
||||
}
|
||||
|
||||
|
||||
def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]:
    """Get properly ordered class labels for a given task.

    The ordering is looked up in the module-level ``CLASS_ORDERINGS`` table so
    that every visualization lists classes in the same sequence.

    Args:
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').
        available_classes: Optional list of available classes to filter and order.
            If None, returns all canonical classes for the task.

    Returns:
        A new list of class labels in canonical order. Labels in
        ``available_classes`` that are not part of the canonical ordering are
        dropped.

    Raises:
        KeyError: If ``task`` has no entry in ``CLASS_ORDERINGS``.

    Examples:
        >>> get_ordered_classes("binary")
        ['No RTS', 'RTS']
        >>> get_ordered_classes("count_regimes", ["None", "Few", "Several"])
        ['None', 'Few', 'Several']

    """
    canonical_order = CLASS_ORDERINGS[task]

    if available_classes is None:
        # Hand out a copy: the canonical lists are shared module-level state and
        # must not be mutated through a caller's reference.
        return list(canonical_order)

    # Filter canonical order to only include available classes, preserving order
    available = set(available_classes)
    return [cls for cls in canonical_order if cls in available]
|
||||
|
||||
|
||||
def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series:
    """Reorder a label-indexed Series into the task's canonical class order.

    Args:
        series: Pandas Series whose index holds class labels.
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').

    Returns:
        The Series reindexed to the canonical ordering of the labels it
        contains; labels unknown to the canonical ordering are not carried over.

    """
    present_labels = series.index.tolist()
    return series.reindex(get_ordered_classes(task, present_labels))
|
||||
|
||||
|
||||
def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render summary statistics about inference results.

    Shows three Streamlit metrics. For the binary task these are the total
    number of predictions plus the count and share of each of the two classes;
    for every other task they are the total, the number of distinct predicted
    classes and the most frequent class.

    Args:
        predictions_gdf: GeoDataFrame with a ``predicted_class`` column.
        task: Task type ('binary', 'count', 'density').

    """
    st.subheader("📊 Inference Summary")

    # Get class distribution (value_counts sorts by frequency, most common first)
    class_counts = predictions_gdf["predicted_class"].value_counts()
    total = len(predictions_gdf)

    def _share(count) -> float:
        # Guard against an empty prediction set.
        return count / total * 100 if total > 0 else 0

    col1, col2, col3 = st.columns(3)

    with col1:
        st.metric("Total Predictions", f"{total:,}")

    if task == "binary":
        with col2:
            rts_count = class_counts.get("RTS", 0)
            st.metric("RTS Predictions", f"{rts_count:,} ({_share(rts_count):.1f}%)")

        with col3:
            # The canonical negative label is "No RTS" (see BINARY_LABELS); the
            # previous lookup of "No-RTS" never matched and always showed 0.
            # The hyphenated spelling is kept as a fallback for older outputs.
            no_rts_count = class_counts.get("No RTS", class_counts.get("No-RTS", 0))
            st.metric("No-RTS Predictions", f"{no_rts_count:,} ({_share(no_rts_count):.1f}%)")
    else:
        with col2:
            st.metric("Unique Classes", len(class_counts))

        with col3:
            most_common = class_counts.index[0] if len(class_counts) > 0 else "N/A"
            st.metric("Most Common Class", most_common)
|
||||
|
||||
|
||||
def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render histogram of predicted class distribution.

    Draws a bar chart of prediction counts per class (in canonical class order)
    and an expandable table with counts and percentages.

    Args:
        predictions_gdf: GeoDataFrame with a ``predicted_class`` column.
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').

    """
    st.subheader("📊 Predicted Class Distribution")

    # Get class counts and order them properly
    class_counts = sort_class_series(predictions_gdf["predicted_class"].value_counts(), task)
    categories = class_counts.index.tolist()
    counts = class_counts.to_numpy()
    percentages = counts / len(predictions_gdf) * 100

    # Derive colors from the full canonical class list so a class keeps the same
    # color as in the map and pie chart even when some classes are not predicted.
    canonical_classes = get_ordered_classes(task)
    palette = get_palette(task, len(canonical_classes))
    color_map = {cls: palette[i] for i, cls in enumerate(canonical_classes)}
    colors = [color_map[cls] for cls in categories]

    # Create bar chart
    fig = go.Figure()

    fig.add_trace(
        go.Bar(
            x=categories,
            y=counts,
            marker_color=colors,
            opacity=0.9,
            text=counts,
            textposition="outside",
            textfont={"size": 12},
            hovertemplate="<b>%{x}</b><br>Count: %{y:,}<br>Percentage: %{customdata:.1f}%<extra></extra>",
            customdata=percentages,
        )
    )

    fig.update_layout(
        height=400,
        margin={"l": 20, "r": 20, "t": 40, "b": 20},
        showlegend=False,
        xaxis_title="Predicted Class",
        yaxis_title="Count",
        # Tilt labels only when there are enough classes to crowd the axis
        xaxis={"tickangle": -45 if len(categories) > 3 else 0},
    )

    st.plotly_chart(fig, width="stretch")

    # Show percentages in a table
    with st.expander("📋 Detailed Class Distribution", expanded=False):
        distribution_df = pd.DataFrame(
            {
                "Class": categories,
                "Count": counts,
                "Percentage": percentages.round(2),
            }
        )
        st.dataframe(distribution_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame):
    """Show the geographic extent of the predictions as Streamlit metrics.

    Args:
        predictions_gdf: GeoDataFrame with predictions. If it has a
            ``cell_area`` column, the summed area is shown as well.

    """
    st.subheader("🌍 Spatial Coverage")

    # total_bounds is (minx, miny, maxx, maxy); in EPSG:4326 x is longitude, y is latitude
    min_lon, min_lat, max_lon, max_lat = predictions_gdf.to_crs("EPSG:4326").total_bounds

    extent_metrics = (
        ("Min Latitude", min_lat),
        ("Max Latitude", max_lat),
        ("Min Longitude", min_lon),
        ("Max Longitude", max_lon),
    )
    for column, (label, value) in zip(st.columns(4), extent_metrics):
        with column:
            st.metric(label, f"{value:.2f}°")

    # Total area is only reported when the per-cell area column exists
    if "cell_area" in predictions_gdf.columns:
        total_area = predictions_gdf["cell_area"].sum()
        st.info(f"📏 **Total Area Covered:** {total_area:,.0f} km²")
|
||||
|
||||
|
||||
def _prepare_geojson_features(display_gdf_wgs84: gpd.GeoDataFrame) -> list:
|
||||
"""Convert GeoDataFrame to GeoJSON features for pydeck.
|
||||
|
||||
Args:
|
||||
display_gdf_wgs84: GeoDataFrame in WGS84 projection with required columns.
|
||||
|
||||
Returns:
|
||||
List of GeoJSON feature dictionaries.
|
||||
|
||||
"""
|
||||
geojson_data = []
|
||||
for _, row in display_gdf_wgs84.iterrows():
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": {
|
||||
"cell_id": str(row["cell_id"]),
|
||||
"predicted_class": str(row["predicted_class"]),
|
||||
"fill_color": row["fill_color"],
|
||||
"elevation": float(row["elevation"]),
|
||||
},
|
||||
}
|
||||
geojson_data.append(feature)
|
||||
return geojson_data
|
||||
|
||||
|
||||
@st.fragment
def render_inference_map(result: TrainingResult):
    """Draw the pydeck map of predicted classes together with its controls.

    Decorated as a Streamlit fragment, so interacting with the class filter,
    the elevation toggle or the opacity slider reruns only this function.

    Args:
        result: TrainingResult whose ``path`` contains
            ``predicted_probabilities.parquet`` and whose ``settings`` provide
            the task and grid type.

    """
    st.subheader("🗺️ Inference Results Map")

    predictions = gpd.read_parquet(result.path / "predicted_probabilities.parquet")
    task = result.settings.task
    grid = result.settings.grid

    # --- Controls -----------------------------------------------------------
    filter_col, elevation_col, opacity_col = st.columns([2, 2, 1])

    with filter_col:
        # Classes actually present in the predictions, in canonical order
        present_classes = get_ordered_classes(task, predictions["predicted_class"].unique().tolist())
        selected_filter = st.selectbox(
            "Filter by Predicted Class",
            options=["All Classes", *present_classes],
            key="inference_map_filter",
        )

    with elevation_col:
        use_elevation = st.checkbox(
            "Enable 3D Elevation",
            value=True,
            help="Show predictions with elevation (requires count/density for meaningful height)",
            key="inference_map_elevation",
        )

    with opacity_col:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="inference_map_opacity",
        )

    # --- Row selection ------------------------------------------------------
    show_all = selected_filter == "All Classes"
    if show_all:
        shown = predictions.copy()
    else:
        shown = predictions[predictions["predicted_class"] == selected_filter].copy()

    if len(shown) == 0:
        st.warning(f"No predictions found for filter: {selected_filter}")
        return

    st.info(f"Displaying {len(shown):,} out of {len(predictions):,} total predictions")

    # pydeck expects geographic coordinates
    shown_wgs84 = shown.to_crs("EPSG:4326")

    # Hex cells get their geometry repaired (antimeridian handling)
    if grid == "hex":
        shown_wgs84["geometry"] = shown_wgs84["geometry"].apply(fix_hex_geometry)

    # --- Styling ------------------------------------------------------------
    # Colors are keyed on the full canonical class list so they stay stable
    # regardless of which classes are present or filtered.
    canonical_classes = get_ordered_classes(task)
    palette = get_palette(task, len(canonical_classes))
    hex_by_class = {label: palette[pos] for pos, label in enumerate(canonical_classes)}

    # "#rrggbb" -> [r, g, b]
    rgb_by_class = {
        label: [int(hex_code.lstrip("#")[start : start + 2], 16) for start in (0, 2, 4)]
        for label, hex_code in hex_by_class.items()
    }
    shown_wgs84["fill_color"] = shown_wgs84["predicted_class"].map(rgb_by_class)

    # Elevation encodes the rank of the class, normalized to [0, 1]
    if use_elevation and len(present_classes) > 1:
        top_rank = len(present_classes) - 1
        height_by_class = {label: rank / top_rank for rank, label in enumerate(present_classes)}
        shown_wgs84["elevation"] = shown_wgs84["predicted_class"].map(height_by_class)
    else:
        shown_wgs84["elevation"] = 0.0

    # --- Map ----------------------------------------------------------------
    features = _prepare_geojson_features(shown_wgs84)

    layer = pdk.Layer(
        "GeoJsonLayer",
        features,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale to 500km height
        pickable=True,
    )

    # Initial camera; tilted and slightly zoomed out when extruding
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=1.5 if use_elevation else 2,
        pitch=45 if use_elevation else 0,
    )

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell ID:</b> {cell_id}<br/><b>Predicted Class:</b> {predicted_class}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    if use_elevation:
        st.info("💡 3D elevation represents class order. Rotate the map by holding Ctrl/Cmd and dragging.")

    # --- Legend -------------------------------------------------------------
    with st.expander("Legend", expanded=True):
        st.markdown("**Predicted Classes:**")

        for label in present_classes:
            swatch = hex_by_class[label]
            shown_count = len(shown[shown["predicted_class"] == label])
            overall_count = len(predictions[predictions["predicted_class"] == label])
            share = overall_count / len(predictions) * 100 if len(predictions) > 0 else 0

            # With a filter active, report both the displayed and the overall count
            if show_all:
                count_text = f"{shown_count:,} ({share:.1f}%)"
            else:
                count_text = f"{shown_count:,} displayed / {overall_count:,} total ({share:.1f}%)"

            st.markdown(
                '<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 20px; height: 20px; background-color: {swatch}; '
                'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                f"<span>{label}: {count_text}</span></div>",
                unsafe_allow_html=True,
            )

        if use_elevation and len(present_classes) > 1:
            st.markdown("---")
            st.markdown("**Elevation (3D):**")
            st.markdown(f"Height represents class order: {present_classes[0]} (low) → {present_classes[-1]} (high)")
|
||||
|
||||
|
||||
def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render comparison plots between different predicted classes.

    Shows a pie chart of class proportions next to a cumulative-share curve in
    which classes are ranked from most to least frequent. Nothing is plotted
    when fewer than two classes are present.

    Args:
        predictions_gdf: GeoDataFrame with a ``predicted_class`` column.
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').

    """
    st.subheader("🔍 Class Comparison")

    # Get class distribution and order properly
    class_counts = sort_class_series(predictions_gdf["predicted_class"].value_counts(), task)

    if len(class_counts) < 2:
        st.info("Need at least 2 classes for comparison.")
        return

    # Colors follow the canonical class order so each class keeps the same
    # color across all plots, even when some classes are absent.
    canonical_classes = get_ordered_classes(task)
    all_colors = get_palette(task, len(canonical_classes))
    color_map = {cls: all_colors[i] for i, cls in enumerate(canonical_classes)}
    colors = [color_map[cls] for cls in class_counts.index]

    col1, col2 = st.columns(2)

    with col1:
        # Closing "**" added: the unbalanced marker rendered as literal asterisks.
        st.markdown("**Class Proportions**")

        fig = go.Figure(
            data=[
                go.Pie(
                    labels=class_counts.index,
                    values=class_counts.to_numpy(),
                    marker_colors=colors,
                    textinfo="label+percent",
                    textposition="auto",
                    hovertemplate="<b>%{label}</b><br>Count: %{value:,}<br>Percentage: %{percent}<extra></extra>",
                )
            ]
        )

        fig.update_layout(
            height=400,
            margin={"l": 20, "r": 20, "t": 20, "b": 20},
            showlegend=True,
        )

        st.plotly_chart(fig, width="stretch")

    with col2:
        # Closing "**" added here as well.
        st.markdown("**Cumulative Distribution**")

        # Rank classes by frequency and accumulate their share of all predictions
        sorted_counts = class_counts.sort_values(ascending=False)
        cumulative = sorted_counts.cumsum()
        cumulative_pct = cumulative / cumulative.iloc[-1] * 100

        fig = go.Figure()

        fig.add_trace(
            go.Scatter(
                x=list(range(len(cumulative))),
                y=cumulative_pct.to_numpy(),
                mode="lines+markers",
                line={"color": colors[0], "width": 3},
                marker={"size": 8},
                customdata=sorted_counts.index,
                hovertemplate="<b>%{customdata}</b><br>Cumulative: %{y:.1f}%<extra></extra>",
            )
        )

        fig.update_layout(
            height=400,
            margin={"l": 20, "r": 20, "t": 20, "b": 20},
            xaxis_title="Class Rank",
            yaxis_title="Cumulative Percentage",
            # Leave headroom above 100% so the last marker is not clipped
            yaxis={"range": [0, 105]},
        )

        st.plotly_chart(fig, width="stretch")
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -220,7 +220,7 @@ def _add_regression_subplot(
|
|||
train_values = z_values[dataset.split == "train"]
|
||||
test_values = z_values[dataset.split == "test"]
|
||||
|
||||
# Determine bin edges based on combined data
|
||||
# Determine bin edges based on complete data
|
||||
all_values = pd.concat([train_values, test_values])
|
||||
# Use a reasonable number of bins
|
||||
n_bins = min(30, int(np.sqrt(len(all_values))))
|
||||
|
|
|
|||
|
|
@ -19,25 +19,25 @@ def render_run_information(selected_result: TrainingResult, refit_metric):
|
|||
"""
|
||||
st.header("📋 Run Information")
|
||||
|
||||
grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}") # ty:ignore[invalid-argument-type]
|
||||
grid_config = GridConfig.from_grid_level((selected_result.run.dataset.grid, selected_result.run.dataset.level))
|
||||
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
with col1:
|
||||
st.metric("Task", selected_result.settings.task.capitalize())
|
||||
st.metric("Task", selected_result.run.task.capitalize())
|
||||
with col2:
|
||||
st.metric("Target", selected_result.settings.target.capitalize())
|
||||
st.metric("Target", selected_result.run.target.capitalize())
|
||||
with col3:
|
||||
st.metric("Grid", grid_config.display_name)
|
||||
with col4:
|
||||
st.metric("Model", selected_result.settings.model.upper())
|
||||
st.metric("Model", selected_result.run.model_type.upper())
|
||||
with col5:
|
||||
st.metric("Trials", len(selected_result.results))
|
||||
st.metric("Trials", selected_result.run.n_trials or "N/A")
|
||||
|
||||
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
|
||||
|
||||
|
||||
def _render_metrics(metrics: dict[str, float]):
|
||||
"""Render a set of metrics in a two-column layout.
|
||||
"""Render a set of metrics in a max-five-column layout.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary of metric names and their values.
|
||||
|
|
@ -57,20 +57,25 @@ def render_metrics_section(selected_result: TrainingResult):
|
|||
selected_result: The selected TrainingResult object.
|
||||
|
||||
"""
|
||||
# Extract metrics for each split
|
||||
test_metrics = selected_result.run.get_metrics_from_split("test")
|
||||
train_metrics = selected_result.run.get_metrics_from_split("train")
|
||||
complete_metrics = selected_result.run.get_metrics_from_split("complete")
|
||||
|
||||
# Test
|
||||
st.header("🎯 Test Set Performance")
|
||||
st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)")
|
||||
_render_metrics(selected_result.test_metrics)
|
||||
_render_metrics(test_metrics)
|
||||
|
||||
# Train
|
||||
st.header("🏋️♂️ Training Set Performance")
|
||||
st.caption("Performance metrics on the training set (best model from hyperparameter search)")
|
||||
_render_metrics(selected_result.train_metrics)
|
||||
_render_metrics(train_metrics)
|
||||
|
||||
# Combined / All
|
||||
# Complete / All
|
||||
st.header("🧮 Overall Performance")
|
||||
st.caption("Overall performance metrics combining training and test sets")
|
||||
_render_metrics(selected_result.combined_metrics)
|
||||
_render_metrics(complete_metrics)
|
||||
|
||||
|
||||
@st.fragment
|
||||
|
|
@ -84,7 +89,7 @@ def render_confusion_matrices(selected_result: TrainingResult):
|
|||
st.header("🎭 Confusion Matrices")
|
||||
|
||||
# Check if this is a classification task
|
||||
if selected_result.settings.task not in ["binary", "count_regimes", "density_regimes"]:
|
||||
if selected_result.run.task not in ["binary", "count_regimes", "density_regimes"]:
|
||||
st.info(
|
||||
"📊 Confusion matrices are only available for classification tasks "
|
||||
"(binary, count_regimes, density_regimes)."
|
||||
|
|
@ -93,11 +98,11 @@ def render_confusion_matrices(selected_result: TrainingResult):
|
|||
return
|
||||
|
||||
# Check if confusion matrix data is available
|
||||
if selected_result.confusion_matrix is None:
|
||||
if selected_result.run.confusion_matrix is None:
|
||||
st.warning("⚠️ No confusion matrix data found for this training result.")
|
||||
return
|
||||
|
||||
cm = selected_result.confusion_matrix
|
||||
cm = selected_result.run.confusion_matrix
|
||||
|
||||
# Add normalization selection
|
||||
st.subheader("Display Options")
|
||||
|
|
@ -131,11 +136,11 @@ def render_confusion_matrices(selected_result: TrainingResult):
|
|||
fig_train = plot_confusion_matrix(cm["train"], title="Training Set", normalize=normalize_mode)
|
||||
st.plotly_chart(fig_train, width="stretch")
|
||||
with cols[2]:
|
||||
# Combined Confusion Matrix
|
||||
st.subheader("Combined")
|
||||
# Complete Confusion Matrix
|
||||
st.subheader("Complete")
|
||||
st.caption("Train + Test sets")
|
||||
fig_combined = plot_confusion_matrix(cm["combined"], title="Combined", normalize=normalize_mode)
|
||||
st.plotly_chart(fig_combined, width="stretch")
|
||||
fig_complete = plot_confusion_matrix(cm["complete"], title="Complete", normalize=normalize_mode)
|
||||
st.plotly_chart(fig_complete, width="stretch")
|
||||
|
||||
|
||||
def render_cv_statistics_section(cv_stats: CVMetricStatistics, test_score: float):
|
||||
|
|
|
|||
|
|
@ -413,13 +413,13 @@ def _render_aggregation_selection(
|
|||
|
||||
col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
|
||||
with col_btn1:
|
||||
if st.button("✅ Select All", use_container_width=True):
|
||||
if st.button("✅ Select All", width="content"):
|
||||
_set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=True)
|
||||
with col_btn2:
|
||||
if st.button("📊 Median Only", use_container_width=True):
|
||||
if st.button("📊 Median Only", width="content"):
|
||||
_set_median_only_aggregations(member_datasets, members_with_aggs, member_aggregations)
|
||||
with col_btn3:
|
||||
if st.button("❌ Deselect All", use_container_width=True):
|
||||
if st.button("❌ Deselect All", width="content"):
|
||||
_set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=False)
|
||||
|
||||
# Render the form with checkboxes
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from entropice.dashboard.plots.experiment_comparison import (
|
|||
create_feature_consistency_plot,
|
||||
create_feature_importance_by_grid_level,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import (
|
||||
AutogluonTrainingResult,
|
||||
TrainingResult,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
|
||||
|
||||
def _extract_feature_importance_from_results(
|
||||
|
|
@ -23,107 +20,25 @@ def _extract_feature_importance_from_results(
|
|||
training_results: List of TrainingResult objects
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: feature, importance, model, grid, level, task, target
|
||||
DataFrame with columns: feature, importance, stddev, model, grid, level, task, target, data_source, grid_level
|
||||
|
||||
"""
|
||||
records = []
|
||||
|
||||
fis = []
|
||||
for tr in training_results:
|
||||
# Load model state if available
|
||||
model_state = tr.load_model_state()
|
||||
if model_state is None:
|
||||
continue
|
||||
fi = tr.run.feature_importance.reset_index().rename(columns={"index": "feature"})
|
||||
fi["model"] = tr.run.model_type
|
||||
fi["grid"] = tr.run.dataset.grid
|
||||
fi["level"] = tr.run.dataset.level
|
||||
fi["task"] = tr.run.task
|
||||
fi["target"] = tr.run.target
|
||||
fis.append(fi)
|
||||
|
||||
info = tr.display_info
|
||||
fi = pd.concat(fis, ignore_index=True)
|
||||
# Add data source categorization
|
||||
fi["data_source"] = fi["feature"].apply(_categorize_feature)
|
||||
fi["grid_level"] = fi["grid"] + "_" + fi["level"].astype(str)
|
||||
|
||||
# Extract feature importance based on available data
|
||||
if "feature_importance" in model_state.data_vars:
|
||||
# eSPA or similar models with direct feature importance
|
||||
importance_data = model_state["feature_importance"]
|
||||
for feature_idx, feature_name in enumerate(importance_data.coords["feature"].values):
|
||||
importance_value = float(importance_data.isel(feature=feature_idx).values)
|
||||
records.append(
|
||||
{
|
||||
"feature": str(feature_name),
|
||||
"importance": importance_value,
|
||||
"model": info.model,
|
||||
"grid": info.grid,
|
||||
"level": info.level,
|
||||
"task": info.task,
|
||||
"target": info.target,
|
||||
}
|
||||
)
|
||||
elif "gain" in model_state.data_vars:
|
||||
# XGBoost-style feature importance
|
||||
gain_data = model_state["gain"]
|
||||
for feature_idx, feature_name in enumerate(gain_data.coords["feature"].values):
|
||||
importance_value = float(gain_data.isel(feature=feature_idx).values)
|
||||
records.append(
|
||||
{
|
||||
"feature": str(feature_name),
|
||||
"importance": importance_value,
|
||||
"model": info.model,
|
||||
"grid": info.grid,
|
||||
"level": info.level,
|
||||
"task": info.task,
|
||||
"target": info.target,
|
||||
}
|
||||
)
|
||||
elif "feature_importances_" in model_state.data_vars:
|
||||
# Random Forest style
|
||||
importance_data = model_state["feature_importances_"]
|
||||
for feature_idx, feature_name in enumerate(importance_data.coords["feature"].values):
|
||||
importance_value = float(importance_data.isel(feature=feature_idx).values)
|
||||
records.append(
|
||||
{
|
||||
"feature": str(feature_name),
|
||||
"importance": importance_value,
|
||||
"model": info.model,
|
||||
"grid": info.grid,
|
||||
"level": info.level,
|
||||
"task": info.task,
|
||||
"target": info.target,
|
||||
}
|
||||
)
|
||||
|
||||
return pd.DataFrame(records)
|
||||
|
||||
|
||||
def _extract_feature_importance_from_autogluon(
|
||||
autogluon_results: list[AutogluonTrainingResult],
|
||||
) -> pd.DataFrame:
|
||||
"""Extract feature importance from AutoGluon results.
|
||||
|
||||
Args:
|
||||
autogluon_results: List of AutogluonTrainingResult objects
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: feature, importance, model, grid, level, task, target
|
||||
|
||||
"""
|
||||
records = []
|
||||
|
||||
for ag in autogluon_results:
|
||||
if ag.feature_importance is None:
|
||||
continue
|
||||
|
||||
info = ag.display_info
|
||||
|
||||
# AutoGluon feature importance is already a DataFrame with features as index
|
||||
for feature_name, importance_value in ag.feature_importance["importance"].items():
|
||||
records.append(
|
||||
{
|
||||
"feature": str(feature_name),
|
||||
"importance": float(importance_value),
|
||||
"model": "autogluon",
|
||||
"grid": info.grid,
|
||||
"level": info.level,
|
||||
"task": info.task,
|
||||
"target": info.target,
|
||||
}
|
||||
)
|
||||
|
||||
return pd.DataFrame(records)
|
||||
return fi
|
||||
|
||||
|
||||
def _categorize_feature(feature_name: str) -> str:
|
||||
|
|
@ -138,46 +53,12 @@ def _categorize_feature(feature_name: str) -> str:
|
|||
return "General"
|
||||
|
||||
|
||||
def _prepare_feature_importance_data(
|
||||
training_results: list[TrainingResult],
|
||||
autogluon_results: list[AutogluonTrainingResult],
|
||||
) -> pd.DataFrame | None:
|
||||
"""Extract and prepare feature importance data.
|
||||
|
||||
Args:
|
||||
training_results: List of RandomSearchCV training results
|
||||
autogluon_results: List of AutoGluon training results
|
||||
|
||||
Returns:
|
||||
DataFrame with feature importance data or None if no data available
|
||||
|
||||
"""
|
||||
fi_df_cv = _extract_feature_importance_from_results(training_results)
|
||||
fi_df_ag = _extract_feature_importance_from_autogluon(autogluon_results)
|
||||
|
||||
if fi_df_cv.empty and fi_df_ag.empty:
|
||||
return None
|
||||
|
||||
# Combine both
|
||||
fi_df = pd.concat([fi_df_cv, fi_df_ag], ignore_index=True)
|
||||
|
||||
# Add data source categorization
|
||||
fi_df["data_source"] = fi_df["feature"].apply(_categorize_feature)
|
||||
fi_df["grid_level"] = fi_df["grid"] + "_" + fi_df["level"].astype(str)
|
||||
|
||||
return fi_df
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_feature_importance_analysis(
|
||||
training_results: list[TrainingResult],
|
||||
autogluon_results: list[AutogluonTrainingResult],
|
||||
):
|
||||
def render_feature_importance_analysis(training_results: list[TrainingResult]):
|
||||
"""Render feature importance analysis section.
|
||||
|
||||
Args:
|
||||
training_results: List of RandomSearchCV training results
|
||||
autogluon_results: List of AutoGluon training results
|
||||
|
||||
"""
|
||||
st.header("🔍 Feature Importance Analysis")
|
||||
|
|
@ -191,13 +72,13 @@ def render_feature_importance_analysis(
|
|||
|
||||
# Extract feature importance
|
||||
with st.spinner("Extracting feature importance from training results..."):
|
||||
fi_df = _prepare_feature_importance_data(training_results, autogluon_results)
|
||||
fi = _extract_feature_importance_from_results(training_results)
|
||||
|
||||
if fi_df is None:
|
||||
if fi is None:
|
||||
st.warning("No feature importance data available. Model state files may be missing.")
|
||||
return
|
||||
|
||||
st.success(f"Extracted feature importance from {len(fi_df)} feature-model combinations")
|
||||
st.success(f"Extracted feature importance from {len(fi)} feature-model combinations")
|
||||
|
||||
# Filters
|
||||
st.subheader("Filters")
|
||||
|
|
@ -205,12 +86,12 @@ def render_feature_importance_analysis(
|
|||
|
||||
with col1:
|
||||
# Task filter
|
||||
available_tasks = ["All", *sorted(fi_df["task"].unique().tolist())]
|
||||
available_tasks = ["All", *sorted(fi["task"].unique().tolist())]
|
||||
selected_task = st.selectbox("Task", options=available_tasks, index=0, key="fi_task_filter")
|
||||
|
||||
with col2:
|
||||
# Target filter
|
||||
available_targets = ["All", *sorted(fi_df["target"].unique().tolist())]
|
||||
available_targets = ["All", *sorted(fi["target"].unique().tolist())]
|
||||
selected_target = st.selectbox("Target Dataset", options=available_targets, index=0, key="fi_target_filter")
|
||||
|
||||
with col3:
|
||||
|
|
@ -218,13 +99,12 @@ def render_feature_importance_analysis(
|
|||
top_n_features = st.number_input("Top N Features", min_value=5, max_value=50, value=15, key="top_n_features")
|
||||
|
||||
# Apply filters
|
||||
filtered_fi_df = fi_df.copy()
|
||||
filtered_fi = fi.copy()
|
||||
if selected_task != "All":
|
||||
filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["task"] == selected_task]
|
||||
filtered_fi = filtered_fi.loc[filtered_fi["task"] == selected_task]
|
||||
if selected_target != "All":
|
||||
filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["target"] == selected_target]
|
||||
|
||||
if len(filtered_fi_df) == 0:
|
||||
filtered_fi = filtered_fi.loc[filtered_fi["target"] == selected_target]
|
||||
if len(filtered_fi) == 0:
|
||||
st.warning("No feature importance data available for the selected filters.")
|
||||
return
|
||||
|
||||
|
|
@ -232,17 +112,17 @@ def render_feature_importance_analysis(
|
|||
st.subheader("Top Features by Grid Level")
|
||||
|
||||
try:
|
||||
fig = create_feature_importance_by_grid_level(filtered_fi_df, top_n=top_n_features)
|
||||
fig = create_feature_importance_by_grid_level(filtered_fi, top_n=top_n_features)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
except Exception as e:
|
||||
st.error(f"Could not create feature importance by grid level plot: {e}")
|
||||
|
||||
# Show detailed breakdown in expander
|
||||
grid_levels = sorted(filtered_fi_df["grid_level"].unique())
|
||||
grid_levels = sorted(filtered_fi["grid_level"].unique())
|
||||
|
||||
with st.expander("Show Detailed Breakdown by Grid Level", expanded=False):
|
||||
for grid_level in grid_levels:
|
||||
grid_data = filtered_fi_df[filtered_fi_df["grid_level"] == grid_level]
|
||||
grid_data = filtered_fi[filtered_fi["grid_level"] == grid_level]
|
||||
|
||||
# Get top features for this grid level
|
||||
top_features_grid = (
|
||||
|
|
@ -271,7 +151,7 @@ def render_feature_importance_analysis(
|
|||
)
|
||||
|
||||
try:
|
||||
fig = create_feature_consistency_plot(filtered_fi_df, top_n=top_n_features)
|
||||
fig = create_feature_consistency_plot(filtered_fi, top_n=top_n_features)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
except Exception as e:
|
||||
st.error(f"Could not create feature consistency plot: {e}")
|
||||
|
|
@ -280,7 +160,7 @@ def render_feature_importance_analysis(
|
|||
with st.expander("Show Detailed Statistics", expanded=False):
|
||||
# Get top features overall
|
||||
overall_top_features = (
|
||||
filtered_fi_df.groupby("feature")["importance"]
|
||||
filtered_fi.groupby("feature")["importance"]
|
||||
.mean()
|
||||
.reset_index()
|
||||
.nlargest(top_n_features, "importance")["feature"]
|
||||
|
|
@ -289,7 +169,7 @@ def render_feature_importance_analysis(
|
|||
|
||||
# Calculate variance in importance across models for each feature
|
||||
feature_variance = (
|
||||
filtered_fi_df[filtered_fi_df["feature"].isin(overall_top_features)]
|
||||
filtered_fi[filtered_fi["feature"].isin(overall_top_features)]
|
||||
.groupby("feature")["importance"]
|
||||
.agg(["mean", "std", "min", "max"])
|
||||
.reset_index()
|
||||
|
|
@ -299,7 +179,7 @@ def render_feature_importance_analysis(
|
|||
|
||||
# Add data source
|
||||
feature_variance = feature_variance.merge(
|
||||
filtered_fi_df[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
|
||||
filtered_fi[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
|
||||
)
|
||||
|
||||
feature_variance.columns = ["Feature", "Mean", "Std Dev", "Min", "Max", "CV", "Data Source"]
|
||||
|
|
@ -314,7 +194,7 @@ def render_feature_importance_analysis(
|
|||
st.subheader("Feature Importance by Data Source")
|
||||
|
||||
try:
|
||||
fig = create_data_source_importance_bars(filtered_fi_df)
|
||||
fig = create_data_source_importance_bars(filtered_fi)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
except Exception as e:
|
||||
st.error(f"Could not create data source importance chart: {e}")
|
||||
|
|
@ -322,9 +202,7 @@ def render_feature_importance_analysis(
|
|||
# Show detailed table in expander
|
||||
with st.expander("Show Data Source Statistics", expanded=False):
|
||||
# Aggregate importance by data source
|
||||
source_importance = (
|
||||
filtered_fi_df.groupby("data_source")["importance"].agg(["sum", "mean", "count"]).reset_index()
|
||||
)
|
||||
source_importance = filtered_fi.groupby("data_source")["importance"].agg(["sum", "mean", "count"]).reset_index()
|
||||
source_importance.columns = ["Data Source", "Total Importance", "Mean Importance", "Feature Count"]
|
||||
source_importance = source_importance.sort_values("Total Importance", ascending=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def render_grid_level_analysis(summary_df: pd.DataFrame, available_metrics: list
|
|||
unique_tasks = summary_df["task"].unique()
|
||||
|
||||
# Split selection
|
||||
split = st.selectbox("Data Split", options=["test", "train", "combined"], index=0, key="grid_split")
|
||||
split = st.selectbox("Data Split", options=["test", "train", "complete"], index=0, key="grid_split")
|
||||
|
||||
# Create plots for each task
|
||||
for task in sorted(unique_tasks):
|
||||
|
|
|
|||
|
|
@ -34,12 +34,12 @@ def render_inference_maps_section(
|
|||
|
||||
# Extract unique grid configurations from training results
|
||||
available_grid_configs = sorted(
|
||||
{GridConfig.from_grid_level((tr.settings.grid, tr.settings.level)) for tr in training_results},
|
||||
{GridConfig.from_grid_level((tr.run.dataset.grid, tr.run.dataset.level)) for tr in training_results},
|
||||
key=lambda gc: gc.sort_key,
|
||||
)
|
||||
available_tasks = sorted({tr.settings.task for tr in training_results})
|
||||
available_targets = sorted({tr.settings.target for tr in training_results})
|
||||
available_models = sorted({tr.settings.model for tr in training_results})
|
||||
available_tasks = sorted({tr.run.task for tr in training_results})
|
||||
available_targets = sorted({tr.run.target for tr in training_results})
|
||||
available_models = sorted({tr.run.model_type for tr in training_results})
|
||||
|
||||
# Create form for selecting parameters
|
||||
with st.form("inference_map_form"):
|
||||
|
|
@ -92,11 +92,11 @@ def render_inference_maps_section(
|
|||
filtered_results = [
|
||||
tr
|
||||
for tr in training_results
|
||||
if tr.settings.grid == selected_grid
|
||||
and tr.settings.level == selected_level
|
||||
and tr.settings.task == selected_task
|
||||
and tr.settings.target in selected_targets
|
||||
and tr.settings.model in selected_models
|
||||
if tr.run.dataset.grid == selected_grid
|
||||
and tr.run.dataset.level == selected_level
|
||||
and tr.run.task == selected_task
|
||||
and tr.run.target in selected_targets
|
||||
and tr.run.model_type in selected_models
|
||||
]
|
||||
|
||||
if not filtered_results:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def render_model_comparison(summary_df: pd.DataFrame, available_metrics: list[st
|
|||
}
|
||||
|
||||
# Split selection
|
||||
split = st.selectbox("Data Split", options=["test", "train", "combined"], index=0, key="model_split")
|
||||
split = st.selectbox("Data Split", options=["test", "train", "complete"], index=0, key="model_split")
|
||||
|
||||
# Get unique tasks
|
||||
unique_tasks = summary_df["task"].unique()
|
||||
|
|
|
|||
|
|
@ -3,11 +3,7 @@
|
|||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.utils.loaders import (
|
||||
AutogluonTrainingResult,
|
||||
TrainingResult,
|
||||
get_available_experiments,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import TrainingResult, get_available_experiments
|
||||
|
||||
|
||||
def render_experiment_sidebar() -> str | None:
|
||||
|
|
@ -38,7 +34,6 @@ def render_experiment_sidebar() -> str | None:
|
|||
def render_experiment_overview(
|
||||
experiment_name: str,
|
||||
training_results: list[TrainingResult],
|
||||
autogluon_results: list[AutogluonTrainingResult],
|
||||
summary_df: pd.DataFrame,
|
||||
):
|
||||
"""Render experiment overview section.
|
||||
|
|
@ -46,7 +41,6 @@ def render_experiment_overview(
|
|||
Args:
|
||||
experiment_name: Name of the experiment
|
||||
training_results: List of RandomSearchCV training results
|
||||
autogluon_results: List of AutoGluon training results
|
||||
summary_df: Summary DataFrame with all results
|
||||
|
||||
"""
|
||||
|
|
@ -56,13 +50,15 @@ def render_experiment_overview(
|
|||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
st.metric("Total Training Runs", len(training_results) + len(autogluon_results))
|
||||
st.metric("Total Training Runs", len(training_results))
|
||||
|
||||
with col2:
|
||||
st.metric("RandomSearchCV Runs", len(training_results))
|
||||
hpsearch_runs = [tr for tr in training_results if tr.run.method_type == "HPOCV"]
|
||||
st.metric("RandomSearchCV Runs", len(hpsearch_runs))
|
||||
|
||||
with col3:
|
||||
st.metric("AutoGluon Runs", len(autogluon_results))
|
||||
autogluon_runs = [tr for tr in training_results if tr.run.method_type == "AutoML"]
|
||||
st.metric("AutoGluon Runs", len(autogluon_runs))
|
||||
|
||||
with col4:
|
||||
unique_configs = summary_df[["grid", "level", "task", "target"]].drop_duplicates()
|
||||
|
|
|
|||
|
|
@ -2,16 +2,15 @@
|
|||
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.utils.loaders import AutogluonTrainingResult, TrainingResult
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.utils.types import (
|
||||
GridConfig,
|
||||
)
|
||||
|
||||
|
||||
def render_training_results_summary(training_results: list[TrainingResult | AutogluonTrainingResult]):
|
||||
def render_training_results_summary(training_results: list[TrainingResult]):
|
||||
"""Render summary metrics for training results."""
|
||||
st.header("📊 Training Results Summary")
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
|
@ -24,7 +23,7 @@ def render_training_results_summary(training_results: list[TrainingResult | Auto
|
|||
st.metric("Total Runs", len(training_results))
|
||||
|
||||
with col3:
|
||||
models = {tr.settings.model for tr in training_results if hasattr(tr.settings, "model")}
|
||||
models = {tr.run.model_type for tr in training_results}
|
||||
st.metric("Model Types", len(models))
|
||||
|
||||
with col4:
|
||||
|
|
@ -34,15 +33,15 @@ def render_training_results_summary(training_results: list[TrainingResult | Auto
|
|||
|
||||
|
||||
@st.fragment
|
||||
def render_experiment_results(training_results: list[TrainingResult | AutogluonTrainingResult]): # noqa: C901
|
||||
def render_experiment_results(training_results: list[TrainingResult]): # noqa: C901
|
||||
"""Render detailed experiment results table and expandable details."""
|
||||
st.header("🎯 Experiment Results")
|
||||
|
||||
# Filters
|
||||
experiments = sorted({tr.experiment for tr in training_results if tr.experiment})
|
||||
tasks = sorted({tr.settings.task for tr in training_results})
|
||||
models = sorted({tr.settings.model if isinstance(tr, TrainingResult) else "autogluon" for tr in training_results})
|
||||
grids = sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results})
|
||||
tasks = sorted({tr.run.task for tr in training_results})
|
||||
models = sorted({tr.run.model_type for tr in training_results})
|
||||
grids = sorted({f"{tr.run.dataset.grid}-{tr.run.dataset.level}" for tr in training_results})
|
||||
|
||||
# Create filter columns
|
||||
filter_cols = st.columns(4)
|
||||
|
|
@ -83,30 +82,21 @@ def render_experiment_results(training_results: list[TrainingResult | AutogluonT
|
|||
)
|
||||
|
||||
# Apply filters
|
||||
filtered_results = training_results
|
||||
filtered_results: list[TrainingResult] = training_results
|
||||
if selected_experiment != "All":
|
||||
filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment]
|
||||
if selected_task != "All":
|
||||
filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task]
|
||||
if selected_model != "All" and selected_model != "autogluon":
|
||||
filtered_results = [
|
||||
tr for tr in filtered_results if isinstance(tr, TrainingResult) and tr.settings.model == selected_model
|
||||
]
|
||||
elif selected_model == "autogluon":
|
||||
filtered_results = [tr for tr in filtered_results if isinstance(tr, AutogluonTrainingResult)]
|
||||
filtered_results = [tr for tr in filtered_results if tr.run.task == selected_task]
|
||||
if selected_model != "All":
|
||||
filtered_results = [tr for tr in filtered_results if tr.run.model_type == selected_model]
|
||||
if selected_grid != "All":
|
||||
filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid]
|
||||
filtered_results = [
|
||||
tr for tr in filtered_results if f"{tr.run.dataset.grid}-{tr.run.dataset.level}" == selected_grid
|
||||
]
|
||||
|
||||
st.subheader("Results Table")
|
||||
|
||||
summary_df = TrainingResult.to_dataframe([tr for tr in filtered_results if isinstance(tr, TrainingResult)])
|
||||
autogluon_df = AutogluonTrainingResult.to_dataframe(
|
||||
[tr for tr in filtered_results if isinstance(tr, AutogluonTrainingResult)]
|
||||
)
|
||||
if len(summary_df) == 0:
|
||||
summary_df = autogluon_df
|
||||
elif len(autogluon_df) > 0:
|
||||
summary_df = pd.concat([summary_df, autogluon_df], ignore_index=True)
|
||||
summary_df = TrainingResult.to_dataframe(filtered_results)
|
||||
|
||||
# Display with color coding for best scores
|
||||
st.dataframe(
|
||||
|
|
@ -120,25 +110,22 @@ def render_experiment_results(training_results: list[TrainingResult | AutogluonT
|
|||
for tr in filtered_results:
|
||||
tr_info = tr.display_info
|
||||
display_name = tr_info.get_display_name("model_first")
|
||||
model = "autogluon" if isinstance(tr, AutogluonTrainingResult) else tr.settings.model
|
||||
cv_splits = tr.settings.cv_splits if hasattr(tr.settings, "cv_splits") else "N/A"
|
||||
model = tr.run.model_type
|
||||
with st.expander(display_name):
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
|
||||
grid_config = GridConfig.from_grid_level((tr.run.dataset.grid, tr.run.dataset.level))
|
||||
st.write(
|
||||
"**Configuration:**\n"
|
||||
f"- **Experiment:** {tr.experiment}\n"
|
||||
f"- **Task:** {tr.settings.task}\n"
|
||||
f"- **Target:** {tr.settings.target}\n"
|
||||
f"- **Task:** {tr.run.task}\n"
|
||||
f"- **Target:** {tr.run.target}\n"
|
||||
f"- **Model:** {model}\n"
|
||||
f"- **Grid:** {grid_config.display_name}\n"
|
||||
f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n"
|
||||
f"- **Temporal Mode:** {tr.settings.temporal_mode}\n"
|
||||
f"- **Members:** {', '.join(tr.settings.members)}\n"
|
||||
f"- **CV Splits:** {cv_splits}\n"
|
||||
f"- **Classes:** {tr.settings.classes}\n"
|
||||
f"- **Temporal Mode:** {tr.run.dataset.temporal_mode}\n"
|
||||
f"- **Members:** {', '.join(tr.run.dataset.members)}\n"
|
||||
)
|
||||
|
||||
file_str = "\n**Files:**\n"
|
||||
|
|
@ -155,29 +142,32 @@ def render_experiment_results(training_results: list[TrainingResult | AutogluonT
|
|||
file_str += f"- 📄 `{file.name}`\n"
|
||||
st.write(file_str)
|
||||
with col2:
|
||||
if isinstance(tr, AutogluonTrainingResult):
|
||||
if tr.run.method_type == "AutoML":
|
||||
st.write("**Leaderboard:**")
|
||||
st.dataframe(tr.leaderboard, width="stretch", hide_index=True)
|
||||
st.dataframe(tr.run.leaderboard, width="stretch", hide_index=True)
|
||||
else:
|
||||
st.write("**CV Score Summary:**")
|
||||
# Extract all test scores
|
||||
metric_df = tr.get_metric_dataframe()
|
||||
metric_df = tr.get_cv_results_dataframe()
|
||||
if metric_df is not None:
|
||||
st.dataframe(metric_df, width="stretch", hide_index=True)
|
||||
else:
|
||||
st.write("No test scores found in results.")
|
||||
|
||||
cv_results = tr.run.cv_results
|
||||
assert cv_results is not None, "CV results should not be None for non-AutoML runs"
|
||||
|
||||
# Show parameter space explored
|
||||
if "initial_K" in tr.results.columns: # Common parameter
|
||||
if "initial_K" in cv_results.columns: # Common parameter
|
||||
st.write("\n**Parameter Ranges Explored:**")
|
||||
for param in ["initial_K", "eps_cl", "eps_e"]:
|
||||
if param in tr.results.columns:
|
||||
min_val = tr.results[param].min()
|
||||
max_val = tr.results[param].max()
|
||||
unique_vals = tr.results[param].nunique()
|
||||
if param in cv_results.columns:
|
||||
min_val = cv_results[param].min()
|
||||
max_val = cv_results[param].max()
|
||||
unique_vals = cv_results[param].nunique()
|
||||
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
||||
|
||||
st.write("**CV Results DataFrame:**")
|
||||
st.dataframe(tr.results, width="stretch", hide_index=True)
|
||||
st.dataframe(cv_results, width="stretch", hide_index=True)
|
||||
|
||||
st.write(f"\n**Path:** `{tr.path}`")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Hyperparameter Space Visualization Section."""
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.hyperparameter_space import (
|
||||
|
|
@ -11,9 +12,10 @@ from entropice.dashboard.plots.hyperparameter_space import (
|
|||
)
|
||||
from entropice.dashboard.utils.formatters import format_metric_name
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.utils.training import HPOCV
|
||||
|
||||
|
||||
def _render_performance_summary(results, refit_metric: str):
|
||||
def _render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||
"""Render performance summary subsection."""
|
||||
best_idx = results[f"mean_test_{refit_metric}"].idxmax()
|
||||
best_row = results.loc[best_idx]
|
||||
|
|
@ -47,7 +49,7 @@ def _render_performance_summary(results, refit_metric: str):
|
|||
st.metric(param_name, formatted_value)
|
||||
|
||||
|
||||
def _render_parameter_distributions(results, param_grid: dict | None):
|
||||
def _render_parameter_distributions(results: pd.DataFrame, param_grid: dict | None):
|
||||
"""Render parameter distributions subsection."""
|
||||
st.subheader("Parameter Distributions")
|
||||
st.caption("Distribution of hyperparameter values explored during random search")
|
||||
|
|
@ -73,7 +75,7 @@ def _render_parameter_distributions(results, param_grid: dict | None):
|
|||
st.plotly_chart(param_charts[param_name], width="stretch")
|
||||
|
||||
|
||||
def _render_score_evolution(results, selected_metric: str):
|
||||
def _render_score_evolution(results: pd.DataFrame, selected_metric: str):
|
||||
"""Render score evolution subsection."""
|
||||
st.subheader("Score Evolution Over Iterations")
|
||||
st.caption(f"How {format_metric_name(selected_metric)} evolved during the random search")
|
||||
|
|
@ -85,7 +87,7 @@ def _render_score_evolution(results, selected_metric: str):
|
|||
st.warning(f"Score evolution not available for metric: {selected_metric}")
|
||||
|
||||
|
||||
def _render_score_vs_parameters(results, selected_metric: str, param_grid: dict | None):
|
||||
def _render_score_vs_parameters(results: pd.DataFrame, selected_metric: str, param_grid: dict | None):
|
||||
"""Render score vs parameters subsection."""
|
||||
st.subheader("Score vs Individual Parameters")
|
||||
st.caption(f"Relationship between {format_metric_name(selected_metric)} and each hyperparameter")
|
||||
|
|
@ -110,7 +112,7 @@ def _render_score_vs_parameters(results, selected_metric: str, param_grid: dict
|
|||
st.plotly_chart(score_vs_param_charts[param_name], width="stretch")
|
||||
|
||||
|
||||
def _render_parameter_correlations(results, selected_metric: str):
|
||||
def _render_parameter_correlations(results: pd.DataFrame, selected_metric: str):
|
||||
"""Render parameter correlations subsection."""
|
||||
st.subheader("Parameter-Score Correlations")
|
||||
st.caption(f"Correlation between numeric parameters and {format_metric_name(selected_metric)}")
|
||||
|
|
@ -122,7 +124,7 @@ def _render_parameter_correlations(results, selected_metric: str):
|
|||
st.info("No numeric parameters found for correlation analysis.")
|
||||
|
||||
|
||||
def _render_parameter_interactions(results, selected_metric: str, param_grid: dict | None):
|
||||
def _render_parameter_interactions(results: pd.DataFrame, selected_metric: str, param_grid: dict | None):
|
||||
"""Render parameter interactions subsection."""
|
||||
st.subheader("Parameter Interactions")
|
||||
st.caption(f"Interaction between parameter pairs and their effect on {format_metric_name(selected_metric)}")
|
||||
|
|
@ -154,19 +156,24 @@ def render_hparam_space_section(selected_result: TrainingResult, selected_metric
|
|||
|
||||
"""
|
||||
st.header("🧩 Hyperparameter Space Exploration")
|
||||
if selected_result.run.method_type != "HPOCV":
|
||||
st.warning("Hyperparameter space visualization is only available for RandomSearchCV runs.")
|
||||
return
|
||||
assert isinstance(selected_result.run.method, HPOCV), "Expected method to be HPOCV for HPOCV runs"
|
||||
|
||||
results = selected_result.results
|
||||
cv_results = selected_result.run.cv_results
|
||||
assert cv_results is not None, "CV results should not be None for HPOCV runs"
|
||||
refit_metric = selected_result._get_best_metric_name()
|
||||
param_grid = selected_result.settings.param_grid
|
||||
param_grid = selected_result.run.method.hpconfig
|
||||
|
||||
_render_performance_summary(results, refit_metric)
|
||||
_render_performance_summary(cv_results, refit_metric)
|
||||
|
||||
_render_parameter_distributions(results, param_grid)
|
||||
_render_parameter_distributions(cv_results, param_grid)
|
||||
|
||||
_render_score_evolution(results, selected_metric)
|
||||
_render_score_evolution(cv_results, selected_metric)
|
||||
|
||||
_render_score_vs_parameters(results, selected_metric, param_grid)
|
||||
_render_score_vs_parameters(cv_results, selected_metric, param_grid)
|
||||
|
||||
_render_parameter_correlations(results, selected_metric)
|
||||
_render_parameter_correlations(cv_results, selected_metric)
|
||||
|
||||
_render_parameter_interactions(results, selected_metric, param_grid)
|
||||
_render_parameter_interactions(cv_results, selected_metric, param_grid)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def render_regression_analysis(selected_result: TrainingResult):
|
|||
st.header("📊 Regression Analysis")
|
||||
|
||||
# Check if this is a regression task
|
||||
if selected_result.settings.task in ["binary", "count_regimes", "density_regimes"]:
|
||||
if selected_result.run.task in ["binary", "count_regimes", "density_regimes"]:
|
||||
st.info("📈 Regression analysis is only available for regression tasks (count, density).")
|
||||
return
|
||||
|
||||
|
|
@ -30,19 +30,18 @@ def render_regression_analysis(selected_result: TrainingResult):
|
|||
# Create DatasetEnsemble from settings
|
||||
with st.spinner("Loading training data to get true values..."):
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=selected_result.settings.grid,
|
||||
level=selected_result.settings.level,
|
||||
members=selected_result.settings.members,
|
||||
temporal_mode=selected_result.settings.temporal_mode,
|
||||
dimension_filters=selected_result.settings.dimension_filters,
|
||||
variable_filters=selected_result.settings.variable_filters,
|
||||
add_lonlat=selected_result.settings.add_lonlat,
|
||||
grid=selected_result.run.dataset.grid,
|
||||
level=selected_result.run.dataset.level,
|
||||
members=selected_result.run.dataset.members,
|
||||
temporal_mode=selected_result.run.dataset.temporal_mode,
|
||||
dimension_filters=selected_result.run.dataset.dimension_filters,
|
||||
variable_filters=selected_result.run.dataset.variable_filters,
|
||||
)
|
||||
|
||||
# Create training set to get true values
|
||||
training_set = ensemble.create_training_set(
|
||||
task=selected_result.settings.task,
|
||||
target=selected_result.settings.target,
|
||||
task=selected_result.run.task,
|
||||
target=selected_result.run.target,
|
||||
device="cpu",
|
||||
cache_mode="read",
|
||||
)
|
||||
|
|
@ -59,7 +58,7 @@ def render_regression_analysis(selected_result: TrainingResult):
|
|||
merged = predictions_df.merge(true_values, on="cell_id", how="inner")
|
||||
merged["split"] = split_series.reindex(merged["cell_id"]).values
|
||||
|
||||
# Get train, test, and combined data
|
||||
# Get train, test, and complete data
|
||||
train_data = merged[merged["split"] == "train"]
|
||||
test_data = merged[merged["split"] == "test"]
|
||||
|
||||
|
|
@ -94,14 +93,14 @@ def render_regression_analysis(selected_result: TrainingResult):
|
|||
st.plotly_chart(fig_train, use_container_width=True)
|
||||
|
||||
with cols[2]:
|
||||
st.markdown("#### Combined")
|
||||
st.markdown("#### Combplete")
|
||||
st.caption("Train + Test sets")
|
||||
fig_combined = plot_regression_scatter(
|
||||
fig_complete = plot_regression_scatter(
|
||||
merged["y"],
|
||||
merged["predicted"],
|
||||
title="Combined",
|
||||
title="Complete",
|
||||
)
|
||||
st.plotly_chart(fig_combined, use_container_width=True)
|
||||
st.plotly_chart(fig_complete, use_container_width=True)
|
||||
|
||||
# Display residual plots
|
||||
st.subheader("Residual Analysis")
|
||||
|
|
@ -118,5 +117,5 @@ def render_regression_analysis(selected_result: TrainingResult):
|
|||
st.plotly_chart(fig_train_res, use_container_width=True)
|
||||
|
||||
with cols[2]:
|
||||
fig_combined_res = plot_residuals(merged["y"], merged["predicted"], title="Combined Residuals")
|
||||
st.plotly_chart(fig_combined_res, use_container_width=True)
|
||||
fig_complete_res = plot_residuals(merged["y"], merged["predicted"], title="Complete Residuals")
|
||||
st.plotly_chart(fig_complete_res, use_container_width=True)
|
||||
|
|
|
|||
|
|
@ -3,24 +3,24 @@
|
|||
import json
|
||||
import pickle
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import antimeridian
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import toml
|
||||
import xarray as xr
|
||||
from shapely.geometry import shape
|
||||
|
||||
import entropice.spatial.grids
|
||||
import entropice.utils.paths
|
||||
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||
from entropice.ml.autogluon_training import AutoGluonTrainingSettings
|
||||
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||
from entropice.ml.randomsearch import TrainingSettings
|
||||
from entropice.utils.training import Training
|
||||
from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
|
||||
|
||||
|
||||
|
|
@ -39,12 +39,7 @@ class TrainingResult:
|
|||
|
||||
path: Path
|
||||
experiment: str
|
||||
settings: TrainingSettings
|
||||
results: pd.DataFrame
|
||||
train_metrics: dict[str, float]
|
||||
test_metrics: dict[str, float]
|
||||
combined_metrics: dict[str, float]
|
||||
confusion_matrix: xr.Dataset | None
|
||||
run: Training
|
||||
created_at: float
|
||||
available_metrics: list[str]
|
||||
files: list[Path]
|
||||
|
|
@ -52,52 +47,16 @@ class TrainingResult:
|
|||
@classmethod
|
||||
def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "TrainingResult":
|
||||
"""Load a TrainingResult from a given result directory path."""
|
||||
result_file = result_path / "search_results.parquet"
|
||||
preds_file = result_path / "predicted_probabilities.parquet"
|
||||
settings_file = result_path / "search_settings.toml"
|
||||
metrics_file = result_path / "metrics.toml"
|
||||
confusion_matrix_file = result_path / "confusion_matrix.nc"
|
||||
all_files = list(result_path.iterdir())
|
||||
if not result_file.exists():
|
||||
raise FileNotFoundError(f"Missing results file in {result_path}")
|
||||
if not settings_file.exists():
|
||||
raise FileNotFoundError(f"Missing settings file in {result_path}")
|
||||
if not preds_file.exists():
|
||||
raise FileNotFoundError(f"Missing predictions file in {result_path}")
|
||||
if not metrics_file.exists():
|
||||
raise FileNotFoundError(f"Missing metrics file in {result_path}")
|
||||
run = Training.load(result_path)
|
||||
|
||||
created_at = result_path.stat().st_ctime
|
||||
settings_dict = toml.load(settings_file)["settings"]
|
||||
|
||||
# Handle backward compatibility: add missing fields with defaults
|
||||
if "classes" not in settings_dict:
|
||||
settings_dict["classes"] = None
|
||||
if "param_grid" not in settings_dict:
|
||||
settings_dict["param_grid"] = {}
|
||||
if "cv_splits" not in settings_dict:
|
||||
settings_dict["cv_splits"] = 5
|
||||
if "metrics" not in settings_dict:
|
||||
settings_dict["metrics"] = []
|
||||
|
||||
settings = TrainingSettings(**settings_dict)
|
||||
results = pd.read_parquet(result_file)
|
||||
metrics = toml.load(metrics_file)
|
||||
if not confusion_matrix_file.exists():
|
||||
confusion_matrix = None
|
||||
else:
|
||||
confusion_matrix = xr.open_dataset(confusion_matrix_file, engine="h5netcdf")
|
||||
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
|
||||
available_metrics = [str(v) for v in run.metrics["metric"].unique()]
|
||||
|
||||
return cls(
|
||||
path=result_path,
|
||||
experiment=experiment_name or "N/A",
|
||||
settings=settings,
|
||||
results=results,
|
||||
train_metrics=metrics["train_metrics"],
|
||||
test_metrics=metrics["test_metrics"],
|
||||
combined_metrics=metrics["combined_metrics"],
|
||||
confusion_matrix=confusion_matrix,
|
||||
run=run,
|
||||
created_at=created_at,
|
||||
available_metrics=available_metrics,
|
||||
files=all_files,
|
||||
|
|
@ -107,11 +66,11 @@ class TrainingResult:
|
|||
def display_info(self) -> TrainingResultDisplayInfo:
|
||||
"""Get display information for the training result."""
|
||||
return TrainingResultDisplayInfo(
|
||||
task=self.settings.task,
|
||||
target=self.settings.target,
|
||||
model=self.settings.model,
|
||||
grid=self.settings.grid,
|
||||
level=self.settings.level,
|
||||
task=self.run.task,
|
||||
target=self.run.target,
|
||||
model=self.run.model_type,
|
||||
grid=self.run.dataset.grid,
|
||||
level=self.run.dataset.level,
|
||||
timestamp=datetime.fromtimestamp(self.created_at),
|
||||
)
|
||||
|
||||
|
|
@ -155,18 +114,21 @@ class TrainingResult:
|
|||
st.error(f"Error loading predictions: {e}")
|
||||
return None
|
||||
|
||||
def get_metric_dataframe(self) -> pd.DataFrame | None:
|
||||
def get_cv_results_dataframe(self) -> pd.DataFrame | None:
|
||||
"""Get a DataFrame of available metrics for this training result."""
|
||||
metric_cols = [col for col in self.results.columns if col.startswith("mean_test_")]
|
||||
results = self.run.cv_results
|
||||
if results is None:
|
||||
return None
|
||||
metric_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
||||
if not metric_cols:
|
||||
return None
|
||||
metric_data = []
|
||||
for col in metric_cols:
|
||||
metric_name = col.replace("mean_test_", "").replace("neg_", "").title()
|
||||
metrics = self.results[col]
|
||||
metrics = results[col]
|
||||
# Check if the metric is negative
|
||||
if col.startswith("mean_test_neg_"):
|
||||
task_multiplier = 1 if self.settings.task != "density" else 100
|
||||
task_multiplier = 1 if self.run.task != "density" else 100
|
||||
task_multiplier *= -1
|
||||
metrics = metrics * task_multiplier
|
||||
|
||||
|
|
@ -184,7 +146,7 @@ class TrainingResult:
|
|||
|
||||
def _get_best_metric_name(self) -> str:
|
||||
"""Get the primary metric name for a given task."""
|
||||
match self.settings.task:
|
||||
match self.run.task:
|
||||
case "binary":
|
||||
return "f1"
|
||||
case "count_regimes" | "density_regimes":
|
||||
|
|
@ -192,6 +154,16 @@ class TrainingResult:
|
|||
case _: # regression tasks
|
||||
return "r2"
|
||||
|
||||
def _get_best_score(self, split: Literal["train", "test", "complete"]) -> float:
|
||||
"""Get the best score for the primary metric on a given split."""
|
||||
metric = self._get_best_metric_name()
|
||||
scores = self.run.metrics[(self.run.metrics["metric"] == metric) & (self.run.metrics["split"] == split)]
|
||||
# Should leave a single row
|
||||
assert len(scores) == 1, (
|
||||
f"Expected exactly one score for metric {metric} and split {split}, but found {len(scores)}: {scores}"
|
||||
)
|
||||
return float(scores["score"].iloc[0])
|
||||
|
||||
@staticmethod
|
||||
def to_dataframe(training_results: list["TrainingResult"]) -> pd.DataFrame:
|
||||
"""Convert a list of TrainingResult objects to a DataFrame for display."""
|
||||
|
|
@ -199,6 +171,8 @@ class TrainingResult:
|
|||
for tr in training_results:
|
||||
info = tr.display_info
|
||||
best_metric_name = tr._get_best_metric_name()
|
||||
best_train_score = tr._get_best_score("train")
|
||||
best_test_score = tr._get_best_score("test")
|
||||
|
||||
record = {
|
||||
"Experiment": tr.experiment if tr.experiment else "N/A",
|
||||
|
|
@ -208,9 +182,9 @@ class TrainingResult:
|
|||
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
|
||||
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
|
||||
"Score-Metric": best_metric_name.title(),
|
||||
"Best Models Score (Train-Set)": tr.train_metrics.get(best_metric_name),
|
||||
"Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
|
||||
"Trials": len(tr.results),
|
||||
"Best Models Score (Train-Set)": best_train_score,
|
||||
"Best Models Score (Test-Set)": best_test_score,
|
||||
"Trials": tr.run.n_trials or "N/A",
|
||||
"Path": str(tr.path.name),
|
||||
}
|
||||
records.append(record)
|
||||
|
|
@ -219,11 +193,15 @@ class TrainingResult:
|
|||
@staticmethod
|
||||
def calculate_inference_maps(training_results: list["TrainingResult"]) -> gpd.GeoDataFrame:
|
||||
"""Calculate the mean and standard deviation of inference maps across multiple training results."""
|
||||
assert len({tr.settings.grid for tr in training_results}) == 1, "All training results must have the same grid"
|
||||
assert len({tr.settings.level for tr in training_results}) == 1, "All training results must have the same level"
|
||||
assert len({tr.run.dataset.grid for tr in training_results}) == 1, (
|
||||
"All training results must have the same grid"
|
||||
)
|
||||
assert len({tr.run.dataset.level for tr in training_results}) == 1, (
|
||||
"All training results must have the same level"
|
||||
)
|
||||
|
||||
grid = training_results[0].settings.grid
|
||||
level = training_results[0].settings.level
|
||||
grid = training_results[0].run.dataset.grid
|
||||
level = training_results[0].run.dataset.level
|
||||
gridfile = entropice.utils.paths.get_grid_file(grid, level)
|
||||
cells = gpd.read_parquet(gridfile, columns=["cell_id", "geometry"])
|
||||
if grid == "hex":
|
||||
|
|
@ -232,11 +210,9 @@ class TrainingResult:
|
|||
|
||||
vals = []
|
||||
for tr in training_results:
|
||||
preds_file = tr.path / "predicted_probabilities.parquet"
|
||||
if not preds_file.exists():
|
||||
continue
|
||||
preds = pd.read_parquet(preds_file, columns=["cell_id", "predicted"]).set_index("cell_id")
|
||||
preds = tr.run.predictions.set_index("cell_id")[["predicted"]]
|
||||
if preds["predicted"].dtype == "category":
|
||||
# We can do this because the categories are ordered
|
||||
preds["predicted"] = preds["predicted"].cat.codes
|
||||
vals.append(preds)
|
||||
all_preds = pd.concat(vals, axis=1)
|
||||
|
|
@ -257,9 +233,6 @@ def load_all_training_results() -> list[TrainingResult]:
|
|||
for result_path in results_dir.iterdir():
|
||||
if not result_path.is_dir():
|
||||
continue
|
||||
# Skip AutoGluon results directory
|
||||
if "autogluon" in result_path.name.lower():
|
||||
continue
|
||||
try:
|
||||
training_result = TrainingResult.from_path(result_path)
|
||||
training_results.append(training_result)
|
||||
|
|
@ -288,155 +261,10 @@ def load_all_training_results() -> list[TrainingResult]:
|
|||
return training_results
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutogluonTrainingResult:
|
||||
"""Wrapper for training result data and metadata."""
|
||||
|
||||
path: Path
|
||||
experiment: str
|
||||
settings: AutoGluonTrainingSettings
|
||||
test_metrics: dict[str, float | dict | pd.DataFrame]
|
||||
leaderboard: pd.DataFrame
|
||||
feature_importance: pd.DataFrame | None
|
||||
created_at: float
|
||||
files: list[Path]
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "AutogluonTrainingResult":
|
||||
"""Load an AutogluonTrainingResult from a given result directory path."""
|
||||
settings_file = result_path / "training_settings.toml"
|
||||
metrics_file = result_path / "test_metrics.pickle"
|
||||
leaderboard_file = result_path / "leaderboard.parquet"
|
||||
feature_importance_file = result_path / "feature_importance.parquet"
|
||||
all_files = list(result_path.iterdir())
|
||||
if not settings_file.exists():
|
||||
raise FileNotFoundError(f"Missing settings file in {result_path}")
|
||||
if not metrics_file.exists():
|
||||
raise FileNotFoundError(f"Missing metrics file in {result_path}")
|
||||
if not leaderboard_file.exists():
|
||||
raise FileNotFoundError(f"Missing leaderboard file in {result_path}")
|
||||
|
||||
created_at = result_path.stat().st_ctime
|
||||
settings_dict = toml.load(settings_file)["settings"]
|
||||
settings = AutoGluonTrainingSettings(**settings_dict)
|
||||
with open(metrics_file, "rb") as f:
|
||||
metrics = pickle.load(f)
|
||||
leaderboard = pd.read_parquet(leaderboard_file)
|
||||
|
||||
if feature_importance_file.exists():
|
||||
feature_importance = pd.read_parquet(feature_importance_file)
|
||||
else:
|
||||
feature_importance = None
|
||||
|
||||
return cls(
|
||||
path=result_path,
|
||||
experiment=experiment_name or "N/A",
|
||||
settings=settings,
|
||||
test_metrics=metrics,
|
||||
leaderboard=leaderboard,
|
||||
feature_importance=feature_importance,
|
||||
created_at=created_at,
|
||||
files=all_files,
|
||||
)
|
||||
|
||||
@property
|
||||
def test_confusion_matrix(self) -> pd.DataFrame | None:
|
||||
"""Get the test confusion matrix."""
|
||||
if "confusion_matrix" not in self.test_metrics:
|
||||
return None
|
||||
assert isinstance(self.test_metrics["confusion_matrix"], pd.DataFrame)
|
||||
return self.test_metrics["confusion_matrix"]
|
||||
|
||||
@property
|
||||
def display_info(self) -> TrainingResultDisplayInfo:
|
||||
"""Get display information for the training result."""
|
||||
return TrainingResultDisplayInfo(
|
||||
task=self.settings.task,
|
||||
target=self.settings.target,
|
||||
model="autogluon",
|
||||
grid=self.settings.grid,
|
||||
level=self.settings.level,
|
||||
timestamp=datetime.fromtimestamp(self.created_at),
|
||||
)
|
||||
|
||||
def _get_best_metric_name(self) -> str:
|
||||
"""Get the primary metric name for a given task."""
|
||||
match self.settings.task:
|
||||
case "binary":
|
||||
return "f1"
|
||||
case "count_regimes" | "density_regimes":
|
||||
return "f1_weighted"
|
||||
case _: # regression tasks
|
||||
return "root_mean_squared_error"
|
||||
|
||||
@staticmethod
|
||||
def to_dataframe(training_results: list["AutogluonTrainingResult"]) -> pd.DataFrame:
|
||||
"""Convert a list of AutogluonTrainingResult objects to a DataFrame for display."""
|
||||
records = []
|
||||
for tr in training_results:
|
||||
info = tr.display_info
|
||||
best_metric_name = tr._get_best_metric_name()
|
||||
|
||||
record = {
|
||||
"Experiment": tr.experiment if tr.experiment else "N/A",
|
||||
"Task": info.task,
|
||||
"Target": info.target,
|
||||
"Model": info.model,
|
||||
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
|
||||
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
|
||||
"Score-Metric": best_metric_name.title(),
|
||||
"Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
|
||||
"Path": str(tr.path.name),
|
||||
}
|
||||
records.append(record)
|
||||
return pd.DataFrame.from_records(records)
|
||||
|
||||
|
||||
@st.cache_data(ttl=300) # Cache for 5 minutes
|
||||
def load_all_autogluon_training_results() -> list[AutogluonTrainingResult]:
|
||||
"""Load all training results from the results directory."""
|
||||
results_dir = entropice.utils.paths.RESULTS_DIR
|
||||
training_results: list[AutogluonTrainingResult] = []
|
||||
incomplete_results: list[tuple[Path, Exception]] = []
|
||||
for result_path in results_dir.iterdir():
|
||||
if not result_path.is_dir():
|
||||
continue
|
||||
# Skip AutoGluon results directory
|
||||
if "autogluon" not in result_path.name.lower():
|
||||
continue
|
||||
try:
|
||||
training_result = AutogluonTrainingResult.from_path(result_path)
|
||||
training_results.append(training_result)
|
||||
except FileNotFoundError as e:
|
||||
is_experiment_dir = False
|
||||
for experiment_path in result_path.iterdir():
|
||||
if not experiment_path.is_dir():
|
||||
continue
|
||||
try:
|
||||
experiment_name = experiment_path.parent.name
|
||||
training_result = AutogluonTrainingResult.from_path(experiment_path, experiment_name)
|
||||
training_results.append(training_result)
|
||||
is_experiment_dir = True
|
||||
except FileNotFoundError as e2:
|
||||
incomplete_results.append((experiment_path, e2))
|
||||
if not is_experiment_dir:
|
||||
incomplete_results.append((result_path, e))
|
||||
|
||||
if len(incomplete_results) > 0:
|
||||
st.warning(
|
||||
f"Found {len(incomplete_results)} incomplete autogluon training results that were skipped:\n - "
|
||||
+ "\n - ".join(f"{p}: {e}" for p, e in incomplete_results)
|
||||
)
|
||||
# Sort by creation time (most recent first)
|
||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||
return training_results
|
||||
|
||||
|
||||
def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Task, TrainingSet]]:
|
||||
"""Load training sets for all target-task combinations in the ensemble."""
|
||||
train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {}
|
||||
train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = defaultdict(dict)
|
||||
for target in all_target_datasets:
|
||||
train_data_dict[target] = {}
|
||||
for task in all_tasks:
|
||||
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
|
||||
return train_data_dict
|
||||
|
|
@ -490,10 +318,6 @@ def load_experiment_training_results(experiment_name: str) -> list[TrainingResul
|
|||
for result_path in experiment_dir.iterdir():
|
||||
if not result_path.is_dir():
|
||||
continue
|
||||
# Skip AutoGluon results
|
||||
if "autogluon" in result_path.name.lower():
|
||||
continue
|
||||
|
||||
try:
|
||||
training_result = TrainingResult.from_path(result_path, experiment_name)
|
||||
training_results.append(training_result)
|
||||
|
|
@ -505,47 +329,11 @@ def load_experiment_training_results(experiment_name: str) -> list[TrainingResul
|
|||
return training_results
|
||||
|
||||
|
||||
def load_experiment_autogluon_results(experiment_name: str) -> list[AutogluonTrainingResult]:
|
||||
"""Load all AutoGluon training results for a specific experiment.
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the experiment directory
|
||||
|
||||
Returns:
|
||||
List of AutogluonTrainingResult objects for the experiment
|
||||
|
||||
"""
|
||||
experiment_dir = entropice.utils.paths.RESULTS_DIR / experiment_name
|
||||
if not experiment_dir.exists():
|
||||
return []
|
||||
|
||||
training_results: list[AutogluonTrainingResult] = []
|
||||
for result_path in experiment_dir.iterdir():
|
||||
if not result_path.is_dir():
|
||||
continue
|
||||
# Only include AutoGluon results
|
||||
if "autogluon" not in result_path.name.lower():
|
||||
continue
|
||||
|
||||
try:
|
||||
training_result = AutogluonTrainingResult.from_path(result_path, experiment_name)
|
||||
training_results.append(training_result)
|
||||
except FileNotFoundError:
|
||||
pass # Skip incomplete results
|
||||
|
||||
# Sort by creation time (most recent first)
|
||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||
return training_results
|
||||
|
||||
|
||||
def create_experiment_summary_df(
|
||||
training_results: list[TrainingResult], autogluon_results: list[AutogluonTrainingResult]
|
||||
) -> pd.DataFrame:
|
||||
def create_experiment_summary_df(training_results: list[TrainingResult]) -> pd.DataFrame:
|
||||
"""Create a summary DataFrame for all results in an experiment.
|
||||
|
||||
Args:
|
||||
training_results: List of TrainingResult objects
|
||||
autogluon_results: List of AutogluonTrainingResult objects
|
||||
|
||||
Returns:
|
||||
DataFrame with summary statistics for the experiment
|
||||
|
|
@ -566,55 +354,26 @@ def create_experiment_summary_df(
|
|||
"grid": info.grid,
|
||||
"level": info.level,
|
||||
"grid_level": f"{info.grid}_{info.level}",
|
||||
"train_score": tr.train_metrics.get(best_metric_name, float("nan")),
|
||||
"test_score": tr.test_metrics.get(best_metric_name, float("nan")),
|
||||
"combined_score": tr.combined_metrics.get(best_metric_name, float("nan")),
|
||||
"train_score": tr._get_best_score("train"),
|
||||
"test_score": tr._get_best_score("test"),
|
||||
"complete_score": tr._get_best_score("complete"),
|
||||
"best_metric": best_metric_name,
|
||||
"n_trials": len(tr.results),
|
||||
"n_trials": tr.run.n_trials or "N/A",
|
||||
"created_at": tr.created_at,
|
||||
"path": tr.path,
|
||||
}
|
||||
|
||||
# Add all train metrics
|
||||
for metric, value in tr.train_metrics.items():
|
||||
for metric, value in tr.run.get_metrics_from_split("train").items():
|
||||
record[f"train_{metric}"] = value
|
||||
|
||||
# Add all test metrics
|
||||
for metric, value in tr.test_metrics.items():
|
||||
for metric, value in tr.run.get_metrics_from_split("test").items():
|
||||
record[f"test_{metric}"] = value
|
||||
|
||||
# Add all combined metrics
|
||||
for metric, value in tr.combined_metrics.items():
|
||||
record[f"combined_{metric}"] = value
|
||||
|
||||
records.append(record)
|
||||
|
||||
# Add AutoGluon results
|
||||
for ag in autogluon_results:
|
||||
info = ag.display_info
|
||||
best_metric_name = ag._get_best_metric_name()
|
||||
|
||||
record = {
|
||||
"method": "AutoGluon",
|
||||
"task": info.task,
|
||||
"target": info.target,
|
||||
"model": "ensemble", # AutoGluon is an ensemble
|
||||
"grid": info.grid,
|
||||
"level": info.level,
|
||||
"grid_level": f"{info.grid}_{info.level}",
|
||||
"train_score": float("nan"), # AutoGluon doesn't separate train scores
|
||||
"test_score": ag.test_metrics.get(best_metric_name, float("nan")),
|
||||
"combined_score": float("nan"),
|
||||
"best_metric": best_metric_name,
|
||||
"n_trials": len(ag.leaderboard),
|
||||
"created_at": ag.created_at,
|
||||
"path": ag.path,
|
||||
}
|
||||
|
||||
# Add test metrics
|
||||
for metric, value in ag.test_metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
record[f"test_{metric}"] = value
|
||||
# Add all complete metrics
|
||||
for metric, value in tr.run.get_metrics_from_split("complete").items():
|
||||
record[f"complete_{metric}"] = value
|
||||
|
||||
records.append(record)
|
||||
|
||||
|
|
|
|||
|
|
@ -278,11 +278,11 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, dict[TemporalMode,
|
|||
dataset_stats: dict[GridLevel, dict[TemporalMode, DatasetStatistics]] = {}
|
||||
for grid_config in grid_configs:
|
||||
dataset_stats[grid_config.id] = {}
|
||||
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
|
||||
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
|
||||
total_cells = len(grid_gdf)
|
||||
assert total_cells > 0, "Grid must contain at least one cell."
|
||||
for temporal_mode in all_temporal_modes:
|
||||
with stopwatch(f"Loading statistics for {grid_config.grid=}, {grid_config.level=}, {temporal_mode=}"):
|
||||
e = DatasetEnsemble(grid=grid_config.grid, level=grid_config.level, temporal_mode=temporal_mode)
|
||||
target_statistics = {}
|
||||
for target in all_target_datasets:
|
||||
|
|
@ -437,23 +437,27 @@ class CVMetricStatistics:
|
|||
mean_cv_std: float | None
|
||||
|
||||
@classmethod
|
||||
def compute(cls, result: TrainingResult, metric: str) -> "CVMetricStatistics":
|
||||
def compute(cls, result: TrainingResult, metric: str) -> "CVMetricStatistics | None":
|
||||
"""Get cross-validation statistics for a metric."""
|
||||
score_col = f"mean_test_{metric}"
|
||||
std_col = f"std_test_{metric}"
|
||||
|
||||
if score_col not in result.results.columns:
|
||||
cv_results = result.run.cv_results
|
||||
if cv_results is None:
|
||||
return None
|
||||
|
||||
if score_col not in cv_results.columns:
|
||||
raise ValueError(f"Metric {metric} not found in results.")
|
||||
|
||||
best_score = result.results[score_col].max()
|
||||
mean_score = result.results[score_col].mean()
|
||||
std_score = result.results[score_col].std()
|
||||
worst_score = result.results[score_col].min()
|
||||
median_score = result.results[score_col].median()
|
||||
best_score = cv_results[score_col].max()
|
||||
mean_score = cv_results[score_col].mean()
|
||||
std_score = cv_results[score_col].std()
|
||||
worst_score = cv_results[score_col].min()
|
||||
median_score = cv_results[score_col].median()
|
||||
|
||||
mean_cv_std = None
|
||||
if std_col in result.results.columns:
|
||||
mean_cv_std = result.results[std_col].mean()
|
||||
if std_col in cv_results.columns:
|
||||
mean_cv_std = cv_results[std_col].mean()
|
||||
|
||||
return CVMetricStatistics(
|
||||
best_score=best_score,
|
||||
|
|
@ -477,10 +481,12 @@ class ParameterSpaceSummary:
|
|||
unique_values: int
|
||||
|
||||
@classmethod
|
||||
def compute(cls, result: TrainingResult, param_col: str) -> "ParameterSpaceSummary":
|
||||
def compute(cls, result: TrainingResult, param_col: str) -> "ParameterSpaceSummary | None":
|
||||
"""Get cross-validation statistics for a metric."""
|
||||
if result.run.cv_results is None:
|
||||
return None
|
||||
param_name = param_col.replace("param_", "")
|
||||
param_values = result.results[param_col].dropna()
|
||||
param_values = result.run.cv_results[param_col].dropna()
|
||||
|
||||
if pd.api.types.is_numeric_dtype(param_values):
|
||||
return ParameterSpaceSummary(
|
||||
|
|
@ -511,14 +517,19 @@ class CVResultsStatistics:
|
|||
parameter_summary: list[ParameterSpaceSummary]
|
||||
|
||||
@classmethod
|
||||
def compute(cls, result: TrainingResult) -> "CVResultsStatistics":
|
||||
def compute(cls, result: TrainingResult) -> "CVResultsStatistics | None":
|
||||
"""Get cross-validation statistics for a metric."""
|
||||
if result.run.cv_results is None:
|
||||
return None
|
||||
metrics = result.available_metrics
|
||||
metric_stats: dict[str, CVMetricStatistics] = {}
|
||||
for metric in metrics:
|
||||
metric_stats[metric] = CVMetricStatistics.compute(result, metric)
|
||||
stats = CVMetricStatistics.compute(result, metric)
|
||||
if stats is None:
|
||||
continue
|
||||
metric_stats[metric] = stats
|
||||
|
||||
param_cols = [col for col in result.results.columns if col.startswith("param_") and col != "params"]
|
||||
param_cols = [col for col in result.run.cv_results.columns if col.startswith("param_") and col != "params"]
|
||||
summary_data = []
|
||||
for param_col in param_cols:
|
||||
summary_data.append(ParameterSpaceSummary.compute(result, param_col))
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
|
|||
# Grid selection
|
||||
grid_options = [gc.display_name for gc in grid_configs]
|
||||
|
||||
grid_level_combined = st.selectbox(
|
||||
grid_level_complete = st.selectbox(
|
||||
"Grid Configuration",
|
||||
options=grid_options,
|
||||
index=0,
|
||||
|
|
@ -43,7 +43,7 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
|
|||
)
|
||||
|
||||
# Find the selected grid config
|
||||
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
|
||||
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_complete)
|
||||
|
||||
# Temporal mode selection
|
||||
temporal_mode = st.selectbox(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from entropice.dashboard.sections.experiment_overview import (
|
|||
)
|
||||
from entropice.dashboard.utils.loaders import (
|
||||
create_experiment_summary_df,
|
||||
load_experiment_autogluon_results,
|
||||
load_experiment_training_results,
|
||||
)
|
||||
|
||||
|
|
@ -43,14 +42,13 @@ def render_experiment_analysis_page():
|
|||
# Load experiment results
|
||||
with st.spinner(f"Loading results for experiment: {selected_experiment}..."):
|
||||
training_results = load_experiment_training_results(selected_experiment)
|
||||
autogluon_results = load_experiment_autogluon_results(selected_experiment)
|
||||
|
||||
if not training_results and not autogluon_results:
|
||||
if not training_results:
|
||||
st.warning(f"No training results found in experiment: {selected_experiment}")
|
||||
st.stop()
|
||||
|
||||
# Create summary DataFrame
|
||||
summary_df = create_experiment_summary_df(training_results, autogluon_results)
|
||||
summary_df = create_experiment_summary_df(training_results)
|
||||
|
||||
# Get available metrics
|
||||
metric_columns = [col for col in summary_df.columns if col.startswith("test_")]
|
||||
|
|
@ -61,7 +59,7 @@ def render_experiment_analysis_page():
|
|||
st.stop()
|
||||
|
||||
# Render analysis sections
|
||||
render_experiment_overview(selected_experiment, training_results, autogluon_results, summary_df)
|
||||
render_experiment_overview(selected_experiment, training_results, summary_df)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
@ -73,7 +71,7 @@ def render_experiment_analysis_page():
|
|||
|
||||
st.divider()
|
||||
|
||||
render_feature_importance_analysis(training_results, autogluon_results)
|
||||
render_feature_importance_analysis(training_results)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,8 @@
|
|||
"""Inference page: Visualization of model inference results across the study region."""
|
||||
|
||||
import geopandas as gpd
|
||||
import streamlit as st
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.inference import (
|
||||
render_class_comparison,
|
||||
render_class_distribution_histogram,
|
||||
render_inference_map,
|
||||
render_inference_statistics,
|
||||
render_spatial_distribution_stats,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
|
||||
|
||||
|
||||
|
|
@ -27,7 +19,9 @@ def render_sidebar_selection(training_results: list[TrainingResult]) -> Training
|
|||
st.header("Select Training Run")
|
||||
|
||||
# Create selection options with task-first naming
|
||||
training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
|
||||
training_options: dict[str, TrainingResult] = {
|
||||
tr.display_info.get_display_name("task_first"): tr for tr in training_results
|
||||
}
|
||||
|
||||
selected_name = st.selectbox(
|
||||
"Training Run",
|
||||
|
|
@ -44,113 +38,14 @@ def render_sidebar_selection(training_results: list[TrainingResult]) -> Training
|
|||
# Show run information in sidebar
|
||||
st.subheader("Run Information")
|
||||
|
||||
st.markdown(f"**Task:** {selected_result.settings.task.capitalize()}")
|
||||
st.markdown(f"**Model:** {selected_result.settings.model.upper()}")
|
||||
st.markdown(f"**Grid:** {selected_result.settings.grid.capitalize()}")
|
||||
st.markdown(f"**Level:** {selected_result.settings.level}")
|
||||
st.markdown(f"**Target:** {selected_result.settings.target.replace('darts_', '')}")
|
||||
|
||||
st.markdown(f"**Task:** {selected_result.run.task.capitalize()}")
|
||||
st.markdown(f"**Model:** {selected_result.run.model_type.upper()}")
|
||||
st.markdown(f"**Grid:** {selected_result.run.dataset.grid.capitalize()}")
|
||||
st.markdown(f"**Level:** {selected_result.run.dataset.level}")
|
||||
st.markdown(f"**Target:** {selected_result.run.target.replace('darts_', '')}")
|
||||
return selected_result
|
||||
|
||||
|
||||
def render_run_information(selected_result: TrainingResult):
    """Show the main settings of the selected training run as a row of metrics.

    Args:
        selected_result: Training result whose ``settings`` (task, model, grid,
            level and target) are displayed.

    """
    st.header("📋 Run Configuration")

    # One column per displayed setting.
    task_col, model_col, grid_col, level_col, target_col = st.columns(5)

    with task_col:
        st.metric("Task", selected_result.settings.task.capitalize())

    with model_col:
        st.metric("Model", selected_result.settings.model.upper())

    with grid_col:
        st.metric("Grid", selected_result.settings.grid.capitalize())

    with level_col:
        st.metric("Level", selected_result.settings.level)

    with target_col:
        # The "darts_" substring is removed from the target name for display.
        st.metric("Target", selected_result.settings.target.replace("darts_", ""))
|
||||
|
||||
|
||||
def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Draw the "Inference Summary" header followed by the inference statistics.

    Args:
        predictions_gdf: GeoDataFrame holding the predictions.
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').

    """
    st.header("📊 Inference Summary")

    # The statistics themselves are produced by the plotting helper.
    render_inference_statistics(predictions_gdf, task)
|
||||
|
||||
|
||||
def render_spatial_coverage_section(predictions_gdf: gpd.GeoDataFrame):
    """Draw the "Spatial Coverage" header followed by the spatial distribution statistics.

    Args:
        predictions_gdf: GeoDataFrame holding the predictions.

    """
    st.header("🌍 Spatial Coverage")

    # The statistics themselves are produced by the plotting helper.
    render_spatial_distribution_stats(predictions_gdf)
|
||||
|
||||
|
||||
def render_map_visualization_section(selected_result: TrainingResult):
    """Draw the interactive prediction map section: header, intro text, then the map.

    Args:
        selected_result: The selected training result, passed on to the map renderer.

    """
    st.header("🗺️ Interactive Prediction Map")

    # Short explanation shown above the map.
    st.markdown(
        """
        3D visualization of predictions across the study region. The map shows predicted
        classes with color coding and spatial distribution of model outputs.
        """
    )

    render_inference_map(selected_result)
|
||||
|
||||
|
||||
def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Draw the "Class Distribution" header, a caption, and the histogram.

    Args:
        predictions_gdf: GeoDataFrame holding the predictions.
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').

    """
    st.header("📈 Class Distribution")
    st.markdown("Distribution of predicted classes across all inference cells.")

    # The histogram itself is produced by the plotting helper.
    render_class_distribution_histogram(predictions_gdf, task)
|
||||
|
||||
|
||||
def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Draw the "Class Comparison Analysis" header, intro text, and the comparison.

    Args:
        predictions_gdf: GeoDataFrame holding the predictions.
        task: Task type ('binary', 'count_regimes', 'density_regimes', 'count', 'density').

    """
    st.header("🔍 Class Comparison Analysis")

    # Short explanation shown above the comparison plots.
    st.markdown(
        """
        Detailed comparison of predicted classes showing probability distributions
        and confidence metrics for different class predictions.
        """
    )

    render_class_comparison(predictions_gdf, task)
|
||||
|
||||
|
||||
def render_inference_page():
|
||||
"""Render the Inference page of the dashboard."""
|
||||
st.title("🗺️ Inference Results")
|
||||
|
|
@ -179,45 +74,18 @@ def render_inference_page():
|
|||
with st.sidebar:
|
||||
selected_result = render_sidebar_selection(training_results)
|
||||
|
||||
# Main content area - Run Information
|
||||
render_run_information(selected_result)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Check if predictions file exists
|
||||
preds_file = selected_result.path / "predicted_probabilities.parquet"
|
||||
if not preds_file.exists():
|
||||
st.error("No inference results found for this training run.")
|
||||
st.info("Inference results are generated automatically during training.")
|
||||
return
|
||||
|
||||
# Main content area
|
||||
# Load predictions
|
||||
with st.spinner("Loading inference results..."):
|
||||
predictions_gdf = gpd.read_parquet(preds_file)
|
||||
task = selected_result.settings.task
|
||||
# Columns: Index(['cell_id', 'predicted', 'geometry'], dtype='object')
|
||||
predictions_gdf = selected_result.run.predictions
|
||||
task = selected_result.run.task
|
||||
|
||||
# Inference Statistics Section
|
||||
render_inference_statistics_section(predictions_gdf, task)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Spatial Coverage Section
|
||||
render_spatial_coverage_section(predictions_gdf)
|
||||
|
||||
st.divider()
|
||||
|
||||
# 3D Map Visualization Section
|
||||
render_map_visualization_section(selected_result)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Class Distribution Section
|
||||
render_class_distribution_section(predictions_gdf, task)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Class Comparison Section
|
||||
render_class_comparison_section(predictions_gdf, task)
|
||||
# TODO: Implement the sections
|
||||
# Map, optionally 3D
|
||||
# Some statistics about the predictions
|
||||
# Class Distribution for classification tasks
|
||||
# Distribution of predicted values for regression tasks
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
|
|
@ -1,919 +0,0 @@
|
|||
"""Model State page: Visualization of model internal state and feature importance."""
|
||||
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_arcticdem_heatmap,
|
||||
plot_arcticdem_summary,
|
||||
plot_box_assignment_bars,
|
||||
plot_box_assignments,
|
||||
plot_common_features,
|
||||
plot_embedding_aggregation_summary,
|
||||
plot_embedding_heatmap,
|
||||
plot_era5_heatmap,
|
||||
plot_era5_summary,
|
||||
plot_era5_time_heatmap,
|
||||
plot_top_features,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import generate_unified_colormap
|
||||
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
|
||||
from entropice.dashboard.utils.unsembler import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
)
|
||||
from entropice.utils.types import L2SourceDataset
|
||||
|
||||
|
||||
def get_members_from_settings(settings) -> list[L2SourceDataset]:
    """Return the dataset members recorded on a training-settings object.

    Args:
        settings: Object exposing a ``members`` attribute (the dashboard passes
            the ``settings`` of a training result).

    Returns:
        The value of ``settings.members``, unchanged.

    """
    members = settings.members
    return members
|
||||
|
||||
|
||||
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
    """Render the training-run selector and return the chosen run.

    Args:
        training_results: Available TrainingResult objects to choose from.

    Returns:
        The TrainingResult whose display name was picked in the select box.

    """
    st.header("Select Training Run")

    # Map each run's task-first display name to its result object. If two runs
    # share a display name, the later one replaces the earlier one.
    runs_by_name = {}
    for result in training_results:
        runs_by_name[result.display_info.get_display_name("task_first")] = result

    chosen_name = st.selectbox(
        "Training Run",
        options=list(runs_by_name),
        index=0,
        help="Choose a training result to visualize model state",
        key="model_state_training_run_select",
    )

    return runs_by_name[chosen_name]
|
||||
|
||||
|
||||
def render_model_info(model_state: xr.Dataset, model_type: str):
    """Show basic facts about the loaded model state in a collapsed expander.

    Args:
        model_state: Xarray dataset containing the model state.
        model_type: Type of model (espa, xgboost, rf, knn).

    """
    info_panel = st.expander("Model State Information", expanded=False)
    with info_panel:
        st.write(f"**Model Type:** {model_type.upper()}")
        # Structure of the dataset: variable names, dimension sizes, coordinates, attributes.
        st.write(f"**Variables:** {list(model_state.data_vars)}")
        st.write(f"**Dimensions:** {dict(model_state.sizes)}")
        st.write(f"**Coordinates:** {list(model_state.coords)}")
        st.write(f"**Attributes:** {dict(model_state.attrs)}")
|
||||
|
||||
|
||||
def render_training_data_summary(members: list[L2SourceDataset]):
    """Render a summary of the data sources used for training.

    Shows the number of dataset members and one info box per member, laid out
    in at most three columns.

    Args:
        members: List of dataset members used in training. May be empty, in
            which case only the header, the count and a short notice are shown.

    """
    st.header("📊 Training Data Summary")

    st.markdown(
        f"""
        **Dataset Members Used in Training:** {len(members)}

        The following data sources were used to train this model:
        """
    )

    # st.columns() requires a positive column count, so an empty member list
    # must not reach the grid construction below.
    if not members:
        st.info("No dataset members are recorded for this training run.")
        return

    # Display labels for the known members; unknown members get a generic label.
    member_display = {
        "AlphaEarth": "🛰️ AlphaEarth (Satellite Embeddings)",
        "ArcticDEM": "🏔️ ArcticDEM (Topography)",
        "ERA5-yearly": "⛅ ERA5 Yearly (Climate)",
        "ERA5-seasonal": "⛅ ERA5 Seasonal (Summer/Winter)",
        "ERA5-shoulder": "⛅ ERA5 Shoulder Seasons (JFM/AMJ/JAS/OND)",
    }

    # Use at most three columns and wrap further members around.
    n_cols = min(len(members), 3)
    cols = st.columns(n_cols)
    for idx, member in enumerate(members):
        with cols[idx % n_cols]:
            display_name = member_display.get(member, f"📁 {member}")
            st.info(display_name)
|
||||
|
||||
|
||||
def render_model_state_page():
    """Render the Model State page of the dashboard.

    Lets the user pick a training run in the sidebar, loads that run's model
    state, shows general information about it and then calls the renderer that
    matches the run's model type.
    """
    st.title("🔬 Model State")
    st.markdown(
        """
        Comprehensive visualization of the best model's internal state and feature importance.
        Select a training run from the sidebar to explore model parameters, feature weights,
        and data source contributions.
        """
    )

    # Nothing can be shown without at least one training result.
    training_results = load_all_training_results()
    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.ml.training`")
        return

    st.success(f"Found **{len(training_results)}** training result(s)")

    st.divider()

    # Sidebar: training run selection
    with st.sidebar:
        selected_result = render_sidebar_selection(training_results)

    model_type = selected_result.settings.model

    with st.spinner("Loading model state..."):
        model_state = selected_result.load_model_state()
        if model_state is None:
            st.error("Could not load model state for this result.")
            st.info("The model state file (best_estimator_state.nc) may be missing from the training results.")
            return

    # General information shown for every model type.
    render_model_info(model_state, model_type)

    members = get_members_from_settings(selected_result.settings)
    render_training_data_summary(members)

    st.divider()

    # Model-specific visualizations, looked up by model type.
    renderers = {
        "espa": render_espa_model_state,
        "xgboost": render_xgboost_model_state,
        "rf": render_rf_model_state,
        "knn": render_knn_model_state,
    }
    renderer = renderers.get(model_type)
    if renderer is not None:
        renderer(model_state, selected_result)
    else:
        st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")

    st.balloons()
    stopwatch.summary()
|
||||
|
||||
|
||||
def render_espa_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
    """Render visualizations for ESPA model.

    Shows the top feature weights, the box-to-label assignments and, for every
    data source used in training, the per-source feature breakdown.

    The dataset passed in is not modified: the feature weights are scaled on a
    copy, so calling this function repeatedly with the same dataset does not
    compound the scaling.

    Args:
        model_state: Xarray dataset containing ESPA model state.
        selected_result: TrainingResult object containing training configuration.

    """
    # Scale feature weights by number of features. The scaling is applied to a
    # copy of the variable and stored in a shallow copy of the dataset so the
    # caller's dataset keeps its original weights.
    n_features = model_state.sizes["feature"]
    scaled_weights = model_state["feature_weights"].copy(deep=True)
    scaled_weights *= n_features
    model_state = model_state.copy(deep=False)
    model_state["feature_weights"] = scaled_weights

    # Get members used in training
    members = get_members_from_settings(selected_result.settings)

    # Extract different feature types based on what was used in training
    embedding_feature_array = None
    if "AlphaEarth" in members:
        embedding_feature_array = extract_embedding_features(model_state)

    # (member name, temporal_group passed to the extractor, label passed to the renderer)
    era5_groups = (
        ("ERA5-yearly", "yearly", "Yearly"),
        ("ERA5-seasonal", "seasonal", "Seasonal"),
        ("ERA5-shoulder", "shoulder", "Shoulder"),
    )
    era5_arrays = [
        (label, extract_era5_features(model_state, temporal_group=group))
        for member, group, label in era5_groups
        if member in members
    ]

    arcticdem_feature_array = None
    if "ArcticDEM" in members:
        arcticdem_feature_array = extract_arcticdem_features(model_state)

    common_feature_array = extract_common_features(model_state)

    # Generate unified colormaps (convert dataclass to dict)
    settings_dict = {"task": selected_result.settings.task, "classes": selected_result.settings.classes}
    _, _, altair_colors = generate_unified_colormap(settings_dict)

    # Feature importance section
    st.header("Feature Importance")
    st.markdown("The most important features based on learned feature weights from the best estimator.")

    @st.fragment
    def render_feature_importance():
        # Slider to control number of features to display
        top_n = st.slider(
            "Number of top features to display",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Select how many of the most important features to visualize",
        )

        with st.spinner("Generating feature importance plot..."):
            feature_chart = plot_top_features(model_state, top_n=top_n)
            st.altair_chart(feature_chart, width="stretch")

        st.markdown(
            """
            **Interpretation:**
            - **Magnitude**: Larger absolute values indicate more important features
            - **Color**: Blue bars indicate positive weights, coral bars indicate negative weights
            """
        )

    render_feature_importance()

    # Box-to-Label Assignment Visualization
    st.header("Box-to-Label Assignments")
    st.markdown(
        """
        This visualization shows how the learned boxes (prototypes in feature space) are
        assigned to different class labels. The ESPA classifier learns K boxes and assigns
        them to classes through the Lambda matrix. Higher values indicate stronger assignment
        of a box to a particular class.
        """
    )

    with st.spinner("Generating box assignment visualizations..."):
        col1, col2 = st.columns([0.7, 0.3])

        with col1:
            st.markdown("### Assignment Heatmap")
            box_assignment_heatmap = plot_box_assignments(model_state)
            st.altair_chart(box_assignment_heatmap, width="stretch")

        with col2:
            st.markdown("### Box Count by Class")
            box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
            st.altair_chart(box_assignment_bars, width="stretch")

    # Show statistics
    with st.expander("Box Assignment Statistics"):
        box_assignments = model_state["box_assignments"].to_pandas()
        st.write("**Assignment Matrix Statistics:**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Boxes", len(box_assignments.columns))
        with col2:
            st.metric("Number of Classes", len(box_assignments.index))
        with col3:
            st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
        with col4:
            st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")

        # Show which boxes are most strongly assigned to each class
        st.write("**Top Box Assignments per Class:**")
        for class_label in box_assignments.index:
            top_boxes = box_assignments.loc[class_label].nlargest(5)
            st.write(
                f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
                f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
            )

    st.markdown(
        """
        **Interpretation:**
        - Each box can be assigned to multiple classes with different strengths
        - Boxes with higher assignment values for a class contribute more to that class's predictions
        - The distribution shows how the model partitions the feature space for classification
        """
    )

    # Embedding features analysis (if present)
    if embedding_feature_array is not None:
        render_embedding_features(embedding_feature_array)

    # ERA5 features analysis (if present) - split by temporal group
    for label, era5_array in era5_arrays:
        if era5_array is not None:
            render_era5_features(era5_array, temporal_group=label)

    # ArcticDEM features analysis (if present)
    if arcticdem_feature_array is not None:
        render_arcticdem_features(arcticdem_feature_array)

    # Common features analysis (if present)
    if common_feature_array is not None:
        render_common_features(common_feature_array)
|
||||
|
||||
|
||||
def render_xgboost_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
    """Render the XGBoost-specific views of a model state.

    Shows the feature importance for a user-selected importance type, a
    comparison across importance types, a few model statistics and the
    per-source feature breakdown for the data sources used in training.

    Args:
        model_state: Xarray dataset containing XGBoost model state.
        selected_result: TrainingResult object containing training configuration.

    """
    from entropice.dashboard.plots.model_state import (
        plot_xgboost_feature_importance,
        plot_xgboost_importance_comparison,
    )

    st.header("🌲 XGBoost Model Analysis")
    st.markdown(
        f"""
        XGBoost gradient boosted tree model with **{model_state.attrs.get("n_trees", "N/A")} trees**.

        **Objective:** {model_state.attrs.get("objective", "N/A")}
        """
    )

    # --- Feature importance for one selectable importance type ---
    st.subheader("Feature Importance Analysis")
    st.markdown(
        """
        XGBoost provides multiple ways to measure feature importance:
        - **Weight**: Number of times a feature is used to split the data
        - **Gain**: Average gain across all splits using the feature
        - **Cover**: Average coverage across all splits using the feature
        - **Total Gain**: Total gain across all splits
        - **Total Cover**: Total coverage across all splits
        """
    )

    importance_type = st.selectbox(
        "Select Importance Type",
        options=["gain", "weight", "cover", "total_gain", "total_cover"],
        index=0,
        help="Choose which importance metric to visualize",
        key="model_state_importance_type",
    )

    top_n = st.slider(
        "Number of top features to display",
        min_value=5,
        max_value=50,
        value=20,
        step=5,
        help="Select how many of the most important features to visualize",
    )

    with st.spinner("Generating feature importance plot..."):
        importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n)
        st.altair_chart(importance_chart, width="stretch")

    # --- Comparison across importance types ---
    st.subheader("Importance Type Comparison")
    st.markdown("Compare the top features across different importance metrics.")

    with st.spinner("Generating importance comparison..."):
        comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
        st.altair_chart(comparison_chart, width="stretch")

    # --- Model statistics ---
    with st.expander("Model Statistics"):
        st.write("**Overall Statistics:**")
        trees_col, features_col = st.columns(2)
        with trees_col:
            st.metric("Number of Trees", model_state.attrs.get("n_trees", "N/A"))
        with features_col:
            st.metric("Total Features", model_state.sizes.get("feature", "N/A"))

    # --- Per-source breakdown ---
    st.subheader("Feature Importance by Data Source")
    st.markdown(
        """
        Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
        ArcticDEM topography, and common features).
        """
    )

    members = get_members_from_settings(selected_result.settings)

    # The extractors are given the variable name built from the selected importance type.
    importance_var = f"feature_importance_{importance_type}"

    embedding_features = (
        extract_embedding_features(model_state, importance_type=importance_var) if "AlphaEarth" in members else None
    )

    # (member name, temporal_group passed to the extractor, label passed to the renderer)
    era5_groups = (
        ("ERA5-yearly", "yearly", "Yearly"),
        ("ERA5-seasonal", "seasonal", "Seasonal"),
        ("ERA5-shoulder", "shoulder", "Shoulder"),
    )
    era5_features = [
        (label, extract_era5_features(model_state, importance_type=importance_var, temporal_group=group))
        for member, group, label in era5_groups
        if member in members
    ]

    arcticdem_features = (
        extract_arcticdem_features(model_state, importance_type=importance_var) if "ArcticDEM" in members else None
    )

    common_features = extract_common_features(model_state, importance_type=importance_var)

    # Render each source's features if present
    if embedding_features is not None:
        render_embedding_features(embedding_features)

    for label, era5_array in era5_features:
        if era5_array is not None:
            render_era5_features(era5_array, temporal_group=label)

    if arcticdem_features is not None:
        render_arcticdem_features(arcticdem_features)

    if common_features is not None:
        render_common_features(common_features)
|
||||
|
||||
|
||||
def render_rf_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
    """Render the Random-Forest-specific views of a model state.

    Shows the feature importance, tree structure statistics when the dataset
    contains them, and the per-source feature breakdown for the data sources
    used in training.

    Args:
        model_state: Xarray dataset containing Random Forest model state.
        selected_result: TrainingResult object containing training configuration.

    """
    from entropice.dashboard.plots.model_state import plot_rf_feature_importance

    st.header("🌳 Random Forest Model Analysis")

    # Check if using cuML (which doesn't provide tree statistics)
    is_cuml = "cuML" in model_state.attrs.get("description", "")

    st.markdown(
        f"""
        Random Forest ensemble with **{model_state.attrs.get("n_estimators", "N/A")} trees**
        (max depth: {model_state.attrs.get("max_depth", "N/A")}).
        """
    )

    if is_cuml:
        st.info("ℹ️ Using cuML GPU-accelerated Random Forest. Individual tree statistics are not available.")

    # Display OOB score if available
    oob_score = model_state.attrs.get("oob_score")
    if oob_score is not None:
        st.info(f"**Out-of-Bag Score:** {oob_score:.4f}")

    # --- Feature importance ---
    st.subheader("Feature Importance (Gini Importance)")
    st.markdown(
        """
        Random Forest uses Gini impurity to measure feature importance. Features with higher
        importance values contribute more to the model's predictions.
        """
    )

    top_n = st.slider(
        "Number of top features to display",
        min_value=5,
        max_value=50,
        value=20,
        step=5,
        help="Select how many of the most important features to visualize",
    )

    with st.spinner("Generating feature importance plot..."):
        importance_chart = plot_rf_feature_importance(model_state, top_n=top_n)
        st.altair_chart(importance_chart, width="stretch")

    # --- Tree statistics (only if available - sklearn RF has them, cuML RF doesn't) ---
    if not is_cuml and "tree_depths" in model_state:
        from entropice.dashboard.plots.model_state import plot_rf_tree_statistics

        st.subheader("Tree Structure Statistics")
        st.markdown("Distribution of tree properties across the forest.")

        with st.spinner("Generating tree statistics..."):
            chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
            depth_col, leaf_col, node_col = st.columns(3)
            with depth_col:
                st.altair_chart(chart_depths, width="stretch")
            with leaf_col:
                st.altair_chart(chart_leaves, width="stretch")
            with node_col:
                st.altair_chart(chart_nodes, width="stretch")

        # Summary numbers for the same per-tree variables plotted above.
        with st.expander("Forest Statistics"):
            st.write("**Overall Statistics:**")
            depths = model_state["tree_depths"].to_pandas()
            leaves = model_state["tree_n_leaves"].to_pandas()
            nodes = model_state["tree_n_nodes"].to_pandas()

            depth_col, leaf_col, node_col = st.columns(3)
            with depth_col:
                st.write("**Tree Depths:**")
                st.metric("Mean Depth", f"{depths.mean():.2f}")
                st.metric("Max Depth", f"{depths.max()}")
                st.metric("Min Depth", f"{depths.min()}")
            with leaf_col:
                st.write("**Leaf Counts:**")
                st.metric("Mean Leaves", f"{leaves.mean():.2f}")
                st.metric("Max Leaves", f"{leaves.max()}")
                st.metric("Min Leaves", f"{leaves.min()}")
            with node_col:
                st.write("**Node Counts:**")
                st.metric("Mean Nodes", f"{nodes.mean():.2f}")
                st.metric("Max Nodes", f"{nodes.max()}")
                st.metric("Min Nodes", f"{nodes.min()}")

    # --- Per-source breakdown ---
    st.subheader("Feature Importance by Data Source")
    st.markdown(
        """
        Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
        ArcticDEM topography, and common features).
        """
    )

    members = get_members_from_settings(selected_result.settings)

    embedding_features = (
        extract_embedding_features(model_state, importance_type="feature_importance")
        if "AlphaEarth" in members
        else None
    )

    # (member name, temporal_group passed to the extractor, label passed to the renderer)
    era5_groups = (
        ("ERA5-yearly", "yearly", "Yearly"),
        ("ERA5-seasonal", "seasonal", "Seasonal"),
        ("ERA5-shoulder", "shoulder", "Shoulder"),
    )
    era5_features = [
        (label, extract_era5_features(model_state, importance_type="feature_importance", temporal_group=group))
        for member, group, label in era5_groups
        if member in members
    ]

    arcticdem_features = (
        extract_arcticdem_features(model_state, importance_type="feature_importance")
        if "ArcticDEM" in members
        else None
    )

    common_features = extract_common_features(model_state, importance_type="feature_importance")

    # Render each source's features if present
    if embedding_features is not None:
        render_embedding_features(embedding_features)

    for label, era5_array in era5_features:
        if era5_array is not None:
            render_era5_features(era5_array, temporal_group=label)

    if arcticdem_features is not None:
        render_arcticdem_features(arcticdem_features)

    if common_features is not None:
        render_common_features(common_features)
|
||||
|
||||
|
||||
def render_knn_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
    """Show the configuration of a fitted K-Nearest-Neighbors model.

    KNN exposes no feature weights, so this section only prints the
    hyperparameters stored in ``model_state.attrs`` plus explanatory text.

    Args:
        model_state: Xarray dataset whose ``attrs`` hold the KNN configuration
            (``n_neighbors``, ``n_samples_fit``, ``weights``, ``algorithm``, ``metric``).
        selected_result: TrainingResult of the selected run. Not read by this
            function; kept so all model-state renderers share one signature.

    """
    st.header("🔍 K-Nearest Neighbors Model Analysis")
    st.markdown(
        """
        K-Nearest Neighbors is a non-parametric, instance-based learning algorithm.
        Unlike tree-based or parametric models, KNN doesn't learn feature weights or build
        a model structure. Instead, it memorizes the training data and makes predictions
        based on the k nearest neighbors.
        """
    )

    # Display model metadata
    st.subheader("Model Configuration")
    attrs = model_state.attrs
    # One inner tuple per column; each entry is (metric label, attribute key).
    column_layout = (
        (("Number of Neighbors (k)", "n_neighbors"), ("Training Samples", "n_samples_fit")),
        (("Weights", "weights"), ("Algorithm", "algorithm")),
        (("Metric", "metric"),),
    )
    for column, entries in zip(st.columns(3), column_layout):
        with column:
            for label, attr_key in entries:
                # Missing attributes are shown as "N/A" instead of raising.
                st.metric(label, attrs.get(attr_key, "N/A"))

    st.info(
        """
        **Note:** KNN doesn't have traditional feature importance or model parameters to visualize.
        The model's behavior depends entirely on:
        - The number of neighbors (k)
        - The distance metric used
        - The weighting scheme for neighbors

        To understand the model better, consider visualizing the decision boundaries on a
        reduced-dimensional representation of your data (e.g., using PCA or t-SNE).
        """
    )
|
||||
|
||||
|
||||
# Helper functions for embedding/era5/common features
|
||||
def render_embedding_features(embedding_feature_array: xr.DataArray):
    """Draw the AlphaEarth embedding section of the model-state page.

    The section consists of three per-dimension bar charts, a detailed
    heatmap and an expander with summary statistics and the ten largest weights.

    Args:
        embedding_feature_array: DataArray containing AlphaEarth embedding feature weights.
            The top-10 table selects the columns ``agg``, ``band`` and ``year``.

    """
    with st.container(border=True):
        st.header("🛰️ Embedding Feature Analysis")
        st.markdown(
            """
            Analysis of embedding features showing which aggregations, bands, and years
            are most important for the model predictions.
            """
        )

        # Summary bar charts
        st.markdown("### Importance by Dimension")
        with st.spinner("Generating dimension summaries..."):
            agg_chart, band_chart, year_chart = plot_embedding_aggregation_summary(embedding_feature_array)
        for column, chart in zip(st.columns(3), (agg_chart, band_chart, year_chart)):
            with column:
                st.altair_chart(chart, width="stretch")

        # Detailed heatmap
        st.markdown("### Detailed Heatmap by Aggregation")
        st.markdown("Shows the weight of each band-year combination for each aggregation type.")
        with st.spinner("Generating heatmap..."):
            heatmap = plot_embedding_heatmap(embedding_feature_array)
        st.altair_chart(heatmap, width="stretch")

        # Statistics
        with st.expander("Embedding Feature Statistics"):
            st.write("**Overall Statistics:**")
            feature_count = embedding_feature_array.size
            weight_mean = float(embedding_feature_array.mean().values)
            weight_max = float(embedding_feature_array.max().values)
            summary = (
                ("Total Embedding Features", feature_count),
                ("Mean Weight", f"{weight_mean:.4f}"),
                ("Max Weight", f"{weight_max:.4f}"),
            )
            for column, (label, value) in zip(st.columns(3), summary):
                with column:
                    st.metric(label, value)

            # Show top embedding features
            st.write("**Top 10 Embedding Features:**")
            weights_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
            strongest = weights_df.nlargest(10, "weight")
            st.dataframe(strongest[["agg", "band", "year", "weight"]], width="stretch")
|
||||
|
||||
|
||||
def render_era5_features(era5_feature_array: xr.DataArray, temporal_group: str = ""):
    """Draw one ERA5 section (summary charts, heatmaps, statistics).

    Args:
        era5_feature_array: ERA5 feature importance array. When it has a
            ``season`` dimension an additional by-time heatmap is shown.
        temporal_group: Name of the temporal grouping (e.g., "Yearly", "Seasonal", "Shoulder").
            Used only for the header and intro text; an empty string omits it.

    """
    if temporal_group:
        group_suffix = f" ({temporal_group})"
    else:
        group_suffix = ""

    with st.container(border=True):
        st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
        if temporal_group:
            temporal_suffix = f" for {temporal_group.lower()} aggregation"
        else:
            temporal_suffix = ""
        st.markdown(
            f"""
            Analysis of ERA5 climate features{temporal_suffix} showing which variables and time periods
            are most important for the model predictions.
            """
        )

        # Summary bar charts
        st.markdown("### Importance by Dimension")
        with st.spinner("Generating ERA5 dimension summaries..."):
            charts = plot_era5_summary(era5_feature_array)

        # Three charts (variable, season, year) are expected for seasonal/shoulder data,
        # two (variable, time) for yearly data. The explicit unpacking keeps the
        # length check: any other count raises a ValueError.
        if len(charts) == 3:
            chart_variable, chart_season, chart_year = charts
            ordered_charts = (chart_variable, chart_season, chart_year)
        else:
            chart_variable, chart_time = charts
            ordered_charts = (chart_variable, chart_time)
        for column, chart in zip(st.columns(len(ordered_charts)), ordered_charts):
            with column:
                st.altair_chart(chart, width="stretch")

        # Detailed heatmap
        st.markdown("### Detailed Heatmap")

        # Seasonal/shoulder arrays carry a "season" dimension; yearly arrays do not.
        has_season = "season" in era5_feature_array.dims
        if has_season:
            caption = "Shows the weight of each variable-season-year combination."
            spinner_text = "Generating ERA5 season heatmap..."
        else:
            caption = "Shows the weight of each variable-time combination."
            spinner_text = "Generating ERA5 heatmap..."
        st.markdown(caption)
        with st.spinner(spinner_text):
            heatmap = plot_era5_heatmap(era5_feature_array)
        st.altair_chart(heatmap, width="stretch")

        if has_season:
            # Add time-based heatmap for seasonal/shoulder
            st.markdown("### By Time Heatmap")
            st.markdown("Shows temporal trends by averaging over seasons.")
            with st.spinner("Generating ERA5 time heatmap..."):
                time_heatmap = plot_era5_time_heatmap(era5_feature_array)
            # The helper may return None, in which case nothing is drawn.
            if time_heatmap is not None:
                st.altair_chart(time_heatmap, width="stretch")

        # Statistics
        with st.expander("ERA5 Feature Statistics"):
            st.write("**Overall Statistics:**")
            feature_count = era5_feature_array.size
            weight_mean = float(era5_feature_array.mean().values)
            weight_max = float(era5_feature_array.max().values)
            summary = (
                ("Total ERA5 Features", feature_count),
                ("Mean Weight", f"{weight_mean:.4f}"),
                ("Max Weight", f"{weight_max:.4f}"),
            )
            for column, (label, value) in zip(st.columns(3), summary):
                with column:
                    st.metric(label, value)

            # Show top ERA5 features
            st.write("**Top 10 ERA5 Features:**")
            weights_df = era5_feature_array.to_dataframe(name="weight").reset_index()
            # Keep every coordinate column and move "weight" to the end.
            coordinate_cols = [name for name in weights_df.columns if name != "weight"]
            strongest = weights_df.nlargest(10, "weight")[[*coordinate_cols, "weight"]]
            st.dataframe(strongest, width="stretch")
|
||||
|
||||
|
||||
def render_arcticdem_features(arcticdem_feature_array: xr.DataArray):
    """Draw the ArcticDEM section (summary charts, heatmap, statistics).

    Args:
        arcticdem_feature_array: DataArray containing ArcticDEM feature weights.
            The top-10 table selects the columns ``variable`` and ``agg``.

    """
    with st.container(border=True):
        st.header("🏔️ ArcticDEM Feature Analysis")
        st.markdown(
            """
            Analysis of ArcticDEM topographic features showing which terrain variables and
            aggregations are most important for the model predictions.
            """
        )

        # Summary bar charts
        st.markdown("### Importance by Dimension")
        with st.spinner("Generating ArcticDEM dimension summaries..."):
            variable_chart, agg_chart = plot_arcticdem_summary(arcticdem_feature_array)
        for column, chart in zip(st.columns(2), (variable_chart, agg_chart)):
            with column:
                st.altair_chart(chart, width="stretch")

        # Detailed heatmap
        st.markdown("### Detailed Heatmap")
        st.markdown("Shows the weight of each variable-aggregation combination.")
        with st.spinner("Generating ArcticDEM heatmap..."):
            heatmap = plot_arcticdem_heatmap(arcticdem_feature_array)
        st.altair_chart(heatmap, width="stretch")

        # Statistics
        with st.expander("ArcticDEM Feature Statistics"):
            st.write("**Overall Statistics:**")
            feature_count = arcticdem_feature_array.size
            weight_mean = float(arcticdem_feature_array.mean().values)
            weight_max = float(arcticdem_feature_array.max().values)
            summary = (
                ("Total ArcticDEM Features", feature_count),
                ("Mean Weight", f"{weight_mean:.4f}"),
                ("Max Weight", f"{weight_max:.4f}"),
            )
            for column, (label, value) in zip(st.columns(3), summary):
                with column:
                    st.metric(label, value)

            # Show top ArcticDEM features
            st.write("**Top 10 ArcticDEM Features:**")
            weights_df = arcticdem_feature_array.to_dataframe(name="weight").reset_index()
            strongest = weights_df.nlargest(10, "weight")
            st.dataframe(strongest[["variable", "agg", "weight"]], width="stretch")
|
||||
|
||||
|
||||
def render_common_features(common_feature_array: xr.DataArray):
    """Draw the common-feature section (bar chart, statistics, interpretation).

    Args:
        common_feature_array: DataArray containing common feature weights.
            The table selects the column ``feature`` next to the weights.

    """
    with st.container(border=True):
        st.header("🗺️ Common Feature Analysis")
        st.markdown(
            """
            Analysis of common features including cell area, water area, land area, land ratio,
            longitude, and latitude. These features provide spatial and geographic context.
            """
        )

        # Bar chart showing all common feature weights
        with st.spinner("Generating common features chart..."):
            bar_chart = plot_common_features(common_feature_array)
        st.altair_chart(bar_chart, width="stretch")

        # Statistics
        with st.expander("Common Feature Statistics"):
            st.write("**Overall Statistics:**")
            feature_count = common_feature_array.size
            weight_mean = float(common_feature_array.mean().values)
            weight_max = float(common_feature_array.max().values)
            weight_min = float(common_feature_array.min().values)
            summary = (
                ("Total Common Features", feature_count),
                ("Mean Weight", f"{weight_mean:.4f}"),
                ("Max Weight", f"{weight_max:.4f}"),
                ("Min Weight", f"{weight_min:.4f}"),
            )
            for column, (label, value) in zip(st.columns(4), summary):
                with column:
                    st.metric(label, value)

            # Show all common features sorted by importance
            st.write("**All Common Features (by absolute weight):**")
            weights_df = common_feature_array.to_dataframe(name="weight").reset_index()
            # Rank by magnitude so strongly negative weights are not pushed to the bottom.
            ranked = weights_df.assign(abs_weight=weights_df["weight"].abs()).sort_values(
                "abs_weight", ascending=False
            )
            st.dataframe(ranked[["feature", "weight", "abs_weight"]], width="stretch")

        st.markdown(
            """
            **Interpretation:**
            - **cell_area, water_area, land_area**: Spatial extent features that may indicate
              size-related patterns
            - **land_ratio**: Proportion of land vs water in each cell
            - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
            - Positive weights indicate features that increase the probability of the positive class
            - Negative weights indicate features that decrease the probability of the positive class
            """
        )
|
||||
|
|
@ -9,7 +9,7 @@ from entropice.dashboard.sections.experiment_results import (
|
|||
render_training_results_summary,
|
||||
)
|
||||
from entropice.dashboard.sections.storage_statistics import render_storage_statistics
|
||||
from entropice.dashboard.utils.loaders import load_all_autogluon_training_results, load_all_training_results
|
||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||
|
||||
|
||||
|
|
@ -27,9 +27,6 @@ def render_overview_page():
|
|||
)
|
||||
# Load training results
|
||||
training_results = load_all_training_results()
|
||||
autogluon_results = load_all_autogluon_training_results()
|
||||
if len(autogluon_results) > 0:
|
||||
training_results.extend(autogluon_results)
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
|
|
|
|||
|
|
@ -51,9 +51,9 @@ def render_analysis_settings_sidebar(training_results: list[TrainingResult]) ->
|
|||
available_metrics = selected_result.available_metrics
|
||||
|
||||
# Try to get refit metric from settings
|
||||
if selected_result.settings.task == "binary":
|
||||
if selected_result.run.task == "binary":
|
||||
refit_metric = "f1"
|
||||
elif selected_result.settings.task in ["count_regimes", "density_regimes"]:
|
||||
elif selected_result.run.task in ["count_regimes", "density_regimes"]:
|
||||
refit_metric = "f1_weighted"
|
||||
else:
|
||||
refit_metric = "r2"
|
||||
|
|
@ -121,14 +121,16 @@ def render_training_analysis_page():
|
|||
st.divider()
|
||||
|
||||
# Render confusion matrices for classification, regression analysis for regression
|
||||
if selected_result.settings.task in ["binary", "count_regimes", "density_regimes"]:
|
||||
if selected_result.run.task in ["binary", "count_regimes", "density_regimes"]:
|
||||
render_confusion_matrices(selected_result)
|
||||
else:
|
||||
render_regression_analysis(selected_result)
|
||||
|
||||
st.divider()
|
||||
|
||||
render_cv_statistics_section(cv_statistics, selected_result.test_metrics.get(selected_metric, float("nan")))
|
||||
if cv_statistics is not None:
|
||||
test_score = selected_result._get_best_score("test")
|
||||
render_cv_statistics_section(cv_statistics, test_score)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
@ -137,7 +139,11 @@ def render_training_analysis_page():
|
|||
st.divider()
|
||||
|
||||
# List all results at the end
|
||||
st.header("📄 All Training Results")
|
||||
st.dataframe(selected_result.results)
|
||||
if selected_result.run.method_type == "HPOCV":
|
||||
st.header("📄 All Cross-Validation Results")
|
||||
st.dataframe(selected_result.run.cv_results)
|
||||
elif selected_result.run.method_type == "AutoML":
|
||||
st.header("📄 Model Leaderboard")
|
||||
st.dataframe(selected_result.run.leaderboard)
|
||||
|
||||
st.balloons()
|
||||
|
|
|
|||
1
src/entropice/experiments/__init__.py
Normal file
1
src/entropice/experiments/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Experiments."""
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
from typing import cast
|
||||
"""Feature Importance Experiment."""
|
||||
|
||||
from typing import Literal, cast
|
||||
|
||||
import cyclopts
|
||||
from stopuhr import stopwatch
|
||||
|
|
@ -13,16 +15,18 @@ from entropice.utils.types import Grid, Model, TargetDataset, Task
|
|||
|
||||
cli = cyclopts.App("entropice-feature-importance")
|
||||
|
||||
EXPERIMENT_NAME = "feature_importance_era5-shoulder_arcticdem"
|
||||
# EXPERIMENT_NAME = "tobis-final-tests"
|
||||
DEV = False
|
||||
|
||||
EXPERIMENT_NAME = "feature_importance_era5-shoulder_arcticdem-v2"
|
||||
if DEV:
|
||||
EXPERIMENT_NAME = "tobis-final-tests"
|
||||
|
||||
|
||||
@cli.default
|
||||
def main(
|
||||
grid: Grid,
|
||||
target: TargetDataset,
|
||||
):
|
||||
def main(grid: Grid, target: TargetDataset, selection: Literal["none", "cluster", "univariate"] = "none"):
|
||||
"""Feature Importance Experiment."""
|
||||
levels = [3, 4, 5, 6] if grid == "hex" else [6, 7, 8, 9, 10]
|
||||
if DEV:
|
||||
levels = [3, 6] if grid == "hex" else [6, 10]
|
||||
for level in levels:
|
||||
print(f"Running feature importance experiment for {grid} grid at level {level}...")
|
||||
|
|
@ -38,9 +42,11 @@ def main(
|
|||
|
||||
# AutoGluon
|
||||
time_limit = 30 * 60 # 30 minutes
|
||||
# time_limit = 60
|
||||
if DEV:
|
||||
time_limit = 2 * 60 # 2 minutes
|
||||
presets = "extreme"
|
||||
# presets = "medium"
|
||||
if DEV:
|
||||
presets = "medium"
|
||||
settings = AutoGluonRunSettings(
|
||||
time_limit=time_limit,
|
||||
presets=presets,
|
||||
|
|
@ -59,11 +65,12 @@ def main(
|
|||
print(f"\nRunning HPOCV for model {model}...")
|
||||
n_iter = {
|
||||
"espa": 300,
|
||||
"xgboost": 100,
|
||||
"rf": 40,
|
||||
"knn": 20,
|
||||
"xgboost": 300,
|
||||
"rf": 100, # RF is slow, so we reduce the number of iterations
|
||||
"knn": 40, # kNN hpspace is small, so we reduce the number of iterations
|
||||
}[model]
|
||||
# n_iter = 3
|
||||
if DEV:
|
||||
n_iter = 3
|
||||
scaler = "standard" if model in ["espa", "knn"] else "none"
|
||||
normalize = scaler != "none"
|
||||
settings = HPOCVRunSettings(
|
||||
|
|
|
|||
79
src/entropice/experiments/feature_importance_alphaearth.py
Normal file
79
src/entropice/experiments/feature_importance_alphaearth.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Feature Importance Experiment."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import cyclopts
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.ml.autogluon import RunSettings as AutoGluonRunSettings
|
||||
from entropice.ml.autogluon import train as train_autogluon
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.hpsearchcv import RunSettings as HPOCVRunSettings
|
||||
from entropice.ml.hpsearchcv import hpsearch_cv
|
||||
from entropice.utils.paths import RESULTS_DIR
|
||||
from entropice.utils.types import Grid, Model, TargetDataset, Task
|
||||
|
||||
cli = cyclopts.App("entropice-feature-importance")
|
||||
|
||||
EXPERIMENT_NAME = "feature_importance_era5-shoulder_arcticdem-v2"
|
||||
|
||||
|
||||
@cli.default
def main(grid: Grid, target: TargetDataset):
    """Run the AlphaEarth-only feature importance experiment.

    For every grid level an AlphaEarth dataset ensemble (median aggregation
    only) is built. For the ``binary`` and ``density`` tasks it is trained
    once with AutoGluon and once per model with the HPO-CV search. The
    collected stopwatch timings are written to a parquet file at the end.

    Args:
        grid: The grid type to use; ``"hex"`` selects levels 3-6, anything else levels 6-10.
        target: The target dataset passed on to both training routines.

    """
    # NOTE(review): EXPERIMENT_NAME in this module reads
    # "feature_importance_era5-shoulder_arcticdem-v2" although only AlphaEarth is used here.
    # Confirm that sharing the results directory (and the timing file name below) is intended.
    autogluon_time_limit = 30 * 60  # 30 minutes
    autogluon_presets = "extreme"
    search_iterations = {
        "espa": 300,
        "xgboost": 300,
        "rf": 100,  # RF is slow, so we reduce the number of iterations
        "knn": 40,  # kNN hpspace is small, so we reduce the number of iterations
    }

    if grid == "hex":
        levels = [3, 4, 5, 6]
    else:
        levels = [6, 7, 8, 9, 10]

    for level in levels:
        print(f"Running feature importance experiment for {grid} grid at level {level}...")
        dimension_filters = {"AlphaEarth": {"agg": ["median"]}}
        dataset_ensemble = DatasetEnsemble(
            grid=grid, level=level, members=["AlphaEarth"], dimension_filters=dimension_filters
        )

        for task in cast(list[Task], ["binary", "density"]):
            print(f"\nRunning for {task}...")
            is_binary = task == "binary"

            # AutoGluon
            autogluon_settings = AutoGluonRunSettings(
                time_limit=autogluon_time_limit,
                presets=autogluon_presets,
                verbosity=2,
                task=task,
                target=target,
            )
            train_autogluon(dataset_ensemble, autogluon_settings, experiment=EXPERIMENT_NAME)

            # HPOCV
            splitter = "stratified_shuffle" if is_binary else "kfold"
            # "espa" is only run for the binary task.
            models: list[Model] = ["xgboost", "rf", "knn", "espa"] if is_binary else ["xgboost", "rf", "knn"]
            for model in models:
                print(f"\nRunning HPOCV for model {model}...")
                hpocv_settings = HPOCVRunSettings(
                    n_iter=search_iterations[model],
                    task=task,
                    target=target,
                    splitter=splitter,
                    model=model,
                    # AlphaEarth Embeddings are already normalized unit vectors
                    scaler="none",
                    normalize=False,
                )
                hpsearch_cv(dataset_ensemble, hpocv_settings, experiment=EXPERIMENT_NAME)

    stopwatch.summary()
    timings_path = RESULTS_DIR / EXPERIMENT_NAME / f"training_times_{target}_{grid}.parquet"
    stopwatch.export().to_parquet(timings_path)
    print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
3348
src/entropice/experiments/feature_selection.ipynb
Normal file
3348
src/entropice/experiments/feature_selection.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -18,6 +18,7 @@ from entropice.spatial import grids
|
|||
from entropice.utils.paths import (
|
||||
DARTS_MLLABELS_DIR,
|
||||
DARTS_V1_DIR,
|
||||
DARTS_V2_DIR,
|
||||
get_darts_file,
|
||||
)
|
||||
from entropice.utils.types import Grid
|
||||
|
|
@ -25,6 +26,8 @@ from entropice.utils.types import Grid
|
|||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
darts_v2_l2_file = DARTS_V2_DIR / "all_prediction_segments_ensemble5.parquet"
|
||||
darts_v2_l2_cov_file = DARTS_V2_DIR / "all_prediction_extent_ensemble5.parquet"
|
||||
darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||
darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||
darts_v1_corrections = DARTS_V1_DIR / "negative_correction.geojson"
|
||||
|
|
@ -131,6 +134,83 @@ def _process_rts_yearly_grid(
|
|||
return darts
|
||||
|
||||
|
||||
@cli.command()
def extract_darts_v2(grid: Grid, level: int):
    """Extract yearly RTS labels from the DARTS-v2 Level-2 dataset.

    Creates a Darts-v2 xarray Dataset on the specified grid and level.
    The Dataset contains the following variables:
    - count: Number of RTS in the cell
    - area_km2: Total area of RTS in the cell (in km^2)
    - covered_area_km2: Area of the cell covered by DARTS (in km^2)
    - coverage: Fraction of the cell covered by DARTS
    - density: Density of RTS area per covered area (area_km2 / covered_area_km2)
    Since the DARTS-v2 Level-2 dataset contains yearly data, all variables are indexed by year as well.
    Thus each variable has dimensions (cell_ids, year).

    Args:
        grid (Grid): The grid type to use.
        level (int): The grid level to use.

    """
    with stopwatch("Load data"):
        grid_gdf, cell_areas = _load_grid(grid, level)
        grid_crs = grid_gdf.crs
        rts_segments = gpd.read_parquet(darts_v2_l2_file).to_crs(grid_crs)
        coverage = gpd.read_parquet(darts_v2_l2_cov_file).to_crs(grid_crs)
        # Need to filter small noise pixels (I do not know where they are coming from)
        # NOTE(review): the 1e9 threshold is in squared CRS units - presumably m^2; confirm.
        is_large_enough = coverage.geometry.area > 1e9
        coverage = coverage[is_large_enough]

    with stopwatch("Assign RTS to grid"):
        rts_per_cell = grid_gdf.overlay(rts_segments, how="intersection")
        coverage_per_cell = grid_gdf.overlay(coverage, how="intersection")

    yearly = _process_rts_yearly_grid(rts_per_cell, coverage_per_cell, cell_areas)
    yearly = _convert_xdggs(yearly, grid, level)
    output_path = get_darts_file(grid, level, version="v2")
    with stopwatch(f"Writing Darts v2 to {output_path}"):
        # mode="w" replaces a store left over from a previous run.
        yearly.to_zarr(output_path, consolidated=False, mode="w")
|
||||
|
||||
|
||||
@cli.command()
def extract_darts_v2_aggregated(grid: Grid, level: int):
    """Extract temporally aggregated (Level-3 style) RTS labels from DARTS-v2.

    Creates a Darts-v2 xarray Dataset on the specified grid and level.
    The Dataset contains the following variables:
    - count: Number of RTS in the cell
    - area_km2: Total area of RTS in the cell (in km^2)
    - covered_area_km2: Area of the cell covered by DARTS (in km^2)
    - coverage: Fraction of the cell covered by DARTS
    - density: Density of RTS area per covered area (area_km2 / covered_area_km2)
    The yearly DARTS-v2 Level-2 geometries are dissolved and then exploded, which merges
    overlapping labels of different years into Level-3 like data.
    Thus each variable has only the dimension (cell_ids).

    Args:
        grid (Grid): The grid type to use.
        level (int): The grid level to use.

    """
    with stopwatch("Load data"):
        grid_gdf, cell_areas = _load_grid(grid, level)
        grid_crs = grid_gdf.crs
        rts_segments = gpd.read_parquet(darts_v2_l2_file).to_crs(grid_crs)
        coverage = gpd.read_parquet(darts_v2_l2_cov_file).to_crs(grid_crs)
        # Need to filter small noise pixels (I do not know where they are coming from)
        # NOTE(review): the 1e9 threshold is in squared CRS units - presumably m^2; confirm.
        is_large_enough = coverage.geometry.area > 1e9
        coverage = coverage[is_large_enough]
        # Remove overlapping labels by dissolving
        rts_segments = rts_segments[["geometry"]].dissolve().explode()
        coverage = coverage[["geometry"]].dissolve().explode()

    with stopwatch("Extract RTS labels"):
        rts_per_cell = grid_gdf.overlay(rts_segments, how="intersection")
        coverage_per_cell = grid_gdf.overlay(coverage, how="intersection")

    aggregated = _process_rts_grid(rts_per_cell, coverage_per_cell, cell_areas)
    aggregated = _convert_xdggs(aggregated, grid, level)
    output_path = get_darts_file(grid, level, version="v2-l3")
    with stopwatch(f"Writing Darts v2 l3 to {output_path}"):
        # mode="w" replaces a store left over from a previous run.
        aggregated.to_zarr(output_path, consolidated=False, mode="w")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def extract_darts_v1(grid: Grid, level: int):
|
||||
"""Extract RTS labels from DARTS-v1 Level-2 dataset.
|
||||
|
|
@ -176,7 +256,6 @@ def extract_darts_v1(grid: Grid, level: int):
|
|||
grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||
|
||||
darts = _process_rts_yearly_grid(grid_l2, grid_cov_l2, cell_areas)
|
||||
|
||||
darts = _convert_xdggs(darts, grid, level)
|
||||
output_path = get_darts_file(grid, level, version="v1")
|
||||
with stopwatch(f"Writing Darts v1 to {output_path}"):
|
||||
|
|
@ -206,10 +285,18 @@ def extract_darts_v1_aggregated(grid: Grid, level: int):
|
|||
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
|
||||
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
|
||||
grid_gdf, cell_areas = _load_grid(grid, level)
|
||||
corrections = gpd.read_file(darts_v1_corrections).to_crs(darts_l2.crs)
|
||||
# Remove overlapping labels by dissolving
|
||||
darts_l2 = darts_l2[["geometry"]].dissolve().explode()
|
||||
darts_cov_l2 = darts_cov_l2[["geometry"]].dissolve().explode()
|
||||
|
||||
with stopwatch("Apply corrections"):
|
||||
# The correction file is just an area of sure negatives
|
||||
# Thus, we first need to remove all RTS labels that intersect with the correction area,
|
||||
darts_l2 = gpd.overlay(darts_l2, corrections, how="difference")
|
||||
# then we need to add the correction area as coverage to the coverage file.
|
||||
darts_cov_l2 = gpd.overlay(darts_cov_l2, corrections, how="union")
|
||||
|
||||
with stopwatch("Extract RTS labels"):
|
||||
grid_l3 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||
grid_cov_l3 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,6 @@ This package contains modules for machine learning workflows:
|
|||
- inference: Batch prediction pipeline for trained classifiers
|
||||
"""
|
||||
|
||||
from . import dataset, inference, randomsearch
|
||||
from . import autogluon, dataset, hpsearchcv, inference
|
||||
|
||||
__all__ = ["dataset", "inference", "randomsearch"]
|
||||
# Public submodules of the package; each name is listed exactly once.
__all__ = ["autogluon", "dataset", "hpsearchcv", "inference"]
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def _compute_metrics_and_confusion_matrix( # noqa: C901
|
|||
complete_scores = predictor.evaluate(complete_data, display=True, detailed_report=True)
|
||||
m = []
|
||||
cm = {}
|
||||
for dataset, scores in zip(["train", "test", "complete"], [train_scores, test_scores, complete_scores]):
|
||||
for split, scores in zip(["train", "test", "complete"], [train_scores, test_scores, complete_scores]):
|
||||
for metric, score in scores.items():
|
||||
if metric == "confusion_matrix":
|
||||
score = cast(pd.DataFrame, score)
|
||||
|
|
@ -58,24 +58,24 @@ def _compute_metrics_and_confusion_matrix( # noqa: C901
|
|||
dims=("y_true", "y_pred"),
|
||||
coords={"y_true": score.index.tolist(), "y_pred": score.columns.tolist()},
|
||||
)
|
||||
cm[dataset] = confusion_matrix
|
||||
cm[split] = confusion_matrix
|
||||
elif metric == "classification_report":
|
||||
score = cast(dict[str, dict[str, float]], score)
|
||||
score.pop("accuracy") # Accuracy is already included as a separate metric
|
||||
macro_avg = score.pop("macro avg")
|
||||
for macro_avg_metric, macro_avg_score in macro_avg.items():
|
||||
metric_name = f"macro_avg_{macro_avg_metric}"
|
||||
m.append({"dataset": dataset, "metric": metric_name, "score": macro_avg_score})
|
||||
m.append({"split": split, "metric": metric_name, "score": macro_avg_score})
|
||||
weighted_avg = score.pop("weighted avg")
|
||||
for weighted_avg_metric, weighted_avg_score in weighted_avg.items():
|
||||
metric_name = f"weighted_avg_{weighted_avg_metric}"
|
||||
m.append({"dataset": dataset, "metric": metric_name, "score": weighted_avg_score})
|
||||
m.append({"split": split, "metric": metric_name, "score": weighted_avg_score})
|
||||
for class_name, class_scores in score.items():
|
||||
class_name = class_name.replace(" ", "-")
|
||||
for class_metric, class_score in class_scores.items():
|
||||
m.append({"dataset": dataset, "metric": f"{class_name}_{class_metric}", "score": class_score})
|
||||
m.append({"split": split, "metric": f"{class_name}_{class_metric}", "score": class_score})
|
||||
else: # Scalar metric
|
||||
m.append({"dataset": dataset, "metric": metric, "score": score})
|
||||
m.append({"split": split, "metric": metric, "score": score})
|
||||
if len(cm) == 0:
|
||||
return pd.DataFrame(m), None
|
||||
elif len(cm) == 3:
|
||||
|
|
@ -100,8 +100,8 @@ def _compute_shap_explanation(
|
|||
output_names=target_labels,
|
||||
)
|
||||
samples = test_data.drop(columns=["label"])
|
||||
if len(samples) > 200:
|
||||
samples = samples.sample(n=200, random_state=42)
|
||||
if len(samples) > 100:
|
||||
samples = samples.sample(n=100, random_state=42)
|
||||
explanation = explainer(samples)
|
||||
return explanation
|
||||
|
||||
|
|
@ -161,10 +161,12 @@ def train(
|
|||
feature_importance = predictor.feature_importance(test_data)
|
||||
metrics, confusion_matrix = _compute_metrics_and_confusion_matrix(predictor, train_data, test_data, complete_data)
|
||||
|
||||
with stopwatch("Explaining model predictions with SHAP..."):
|
||||
explanation = _compute_shap_explanation(
|
||||
predictor, train_data, test_data, training_data.feature_names, training_data.target_labels
|
||||
)
|
||||
# ?: GPU inference is not yet implemented in AutoGluon, hence SHAP computation takes ages for large model ensembles,
|
||||
# as they are present in the higher quality presets. Disabling SHAP for now...
|
||||
# with stopwatch("Explaining model predictions with SHAP..."):
|
||||
# explanation = _compute_shap_explanation(
|
||||
# predictor, train_data, test_data, training_data.feature_names, training_data.target_labels
|
||||
# )
|
||||
|
||||
print("Predicting probabilities for all cells...")
|
||||
preds = predict_proba(dataset_ensemble, model=predictor, task=settings.task)
|
||||
|
|
@ -176,12 +178,11 @@ def train(
|
|||
method=AutoML(time_budget=settings.time_limit, preset=settings.presets, hpo=False),
|
||||
task=settings.task,
|
||||
target=settings.target,
|
||||
training_set=training_data,
|
||||
model=predictor,
|
||||
model_type="autogluon",
|
||||
metrics=metrics,
|
||||
feature_importance=feature_importance,
|
||||
shap_explanation=explanation,
|
||||
shap_explanation=None,
|
||||
predictions=preds,
|
||||
confusion_matrix=confusion_matrix,
|
||||
cv_results=None,
|
||||
|
|
|
|||
|
|
@ -1,208 +0,0 @@
|
|||
"""DePRECATED!!! Training with AutoGluon TabularPredictor for automated ML."""
|
||||
|
||||
import pickle
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import cyclopts
|
||||
import pandas as pd
|
||||
import toml
|
||||
from autogluon.tabular import TabularDataset, TabularPredictor
|
||||
from rich import pretty, traceback
|
||||
from sklearn import set_config
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.paths import get_training_results_dir
|
||||
from entropice.utils.types import TargetDataset, Task
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
set_config(array_api_dispatch=False)
|
||||
|
||||
cli = cyclopts.App("entropice-autogluon")
|
||||
|
||||
|
||||
@cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True)
class AutoGluonSettings:
    """AutoGluon training settings.

    Immutable, keyword-only bundle of the options consumed by ``autogluon_train``:
    ``task`` and ``target`` select the training set and the results directory,
    ``verbosity`` is handed to the ``TabularPredictor`` constructor and the
    remaining fields are forwarded to ``TabularPredictor.fit``.
    """

    # Training task; also decides the AutoGluon problem type and the default metric.
    task: Task = "binary"
    # Target dataset the labels are taken from.
    target: TargetDataset = "darts_v1"
    time_limit: int = 3600  # Time limit in seconds (1 hour default)
    presets: str = "best"  # AutoGluon preset: 'best', 'high', 'good', 'medium'
    # When None, the task-specific default metric is used instead.
    eval_metric: str | None = None  # Evaluation metric, None for auto-detect
    num_bag_folds: int = 5  # Number of folds for bagging
    num_bag_sets: int = 1  # Number of bagging sets
    num_stack_levels: int = 1  # Number of stacking levels
    # Passed to fit() as `num_gpus`; values > 0 are additionally passed through `ag_args_fit`.
    num_gpus: int = 1  # Number of GPUs to use
    verbosity: int = 2  # Verbosity level (0-4)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
    """Combined settings for AutoGluon training.

    Merges the dataset ensemble configuration with the AutoGluon settings and the
    values derived during a run, so that everything can be dumped into a single
    ``training_settings.toml`` next to the results.
    """

    # Target labels of the training set; filled from `training_data.target_labels`.
    classes: list[str] | None = None
    # AutoGluon problem type that was derived from the task.
    problem_type: str = "binary"
|
||||
|
||||
|
||||
def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
|
||||
"""Determine AutoGluon problem type and appropriate evaluation metric.
|
||||
|
||||
Args:
|
||||
task: The training task type
|
||||
|
||||
Returns:
|
||||
Tuple of (problem_type, eval_metric)
|
||||
|
||||
"""
|
||||
if task == "binary":
|
||||
return ("binary", "balanced_accuracy") # Good for imbalanced datasets
|
||||
elif task in ["count_regimes", "density_regimes"]:
|
||||
return ("multiclass", "f1_weighted") # Weighted F1 for multiclass
|
||||
elif task in ["count", "density"]:
|
||||
return ("regression", "mean_absolute_error")
|
||||
else:
|
||||
raise ValueError(f"Unknown task: {task}")
|
||||
|
||||
|
||||
@cli.default
def autogluon_train(
    dataset_ensemble: DatasetEnsemble,
    settings: AutoGluonSettings = AutoGluonSettings(),
    experiment: str | None = None,
):
    """Train models using AutoGluon TabularPredictor.

    Builds the train/test frames from the dataset ensemble, fits a ``TabularPredictor``,
    evaluates it on the test split and stores leaderboard, feature importance, settings,
    test metrics and the predictor itself inside the results directory.

    Args:
        dataset_ensemble: Dataset ensemble configuration
        settings: AutoGluon training settings
        experiment: Optional experiment name for organizing results

    """
    training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)

    # Convert to AutoGluon TabularDataset
    train_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("train"))  # ty:ignore[invalid-assignment]
    test_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("test"))  # ty:ignore[invalid-assignment]

    print(f"\nTraining data: {len(train_data)} samples")
    print(f"Test data: {len(test_data)} samples")
    print(f"Features: {len(training_data.feature_names)}")
    print(f"Classes: {training_data.target_labels}")

    # Determine problem type and metric; an explicitly configured metric wins over the task default
    problem_type, default_metric = _determine_problem_type_and_metric(settings.task)
    eval_metric = settings.eval_metric or default_metric

    print(f"\n🎯 Problem type: {problem_type}")
    print(f"📈 Evaluation metric: {eval_metric}")

    # Create results directory
    results_dir = get_training_results_dir(
        experiment=experiment,
        grid=dataset_ensemble.grid,
        level=dataset_ensemble.level,
        task=settings.task,
        target=settings.target,
        name="autogluon",
    )
    print(f"\n💾 Results directory: {results_dir}")

    # The predictor persists itself into this directory (see predictor.save() below)
    models_dir = results_dir / "models"

    # Initialize TabularPredictor
    print(f"\n🚀 Initializing AutoGluon TabularPredictor (preset='{settings.presets}')...")
    predictor = TabularPredictor(
        label="label",
        problem_type=problem_type,
        eval_metric=eval_metric,
        path=str(models_dir),
        verbosity=settings.verbosity,
    )

    # Train models
    print(f"\n⚡ Training models (time_limit={settings.time_limit}s, num_gpus={settings.num_gpus})...")
    with stopwatch("AutoGluon training"):
        predictor.fit(
            train_data=train_data,
            time_limit=settings.time_limit,
            presets=settings.presets,
            num_bag_folds=settings.num_bag_folds,
            num_bag_sets=settings.num_bag_sets,
            num_stack_levels=settings.num_stack_levels,
            num_gpus=settings.num_gpus,
            ag_args_fit={"num_gpus": settings.num_gpus} if settings.num_gpus > 0 else None,
        )

    # Evaluate on test data
    print("\n📊 Evaluating on test data...")
    test_score = predictor.evaluate(test_data, silent=True, detailed_report=True)
    print(f"Test {eval_metric}: {test_score[eval_metric]:.4f}")

    # Get leaderboard
    print("\n🏆 Model Leaderboard:")
    leaderboard = predictor.leaderboard(test_data, silent=True)
    print(leaderboard[["model", "score_test", "score_val", "pred_time_test", "fit_time"]].head(10))

    # Save leaderboard
    leaderboard_file = results_dir / "leaderboard.parquet"
    print(f"\n💾 Saving leaderboard to {leaderboard_file}")
    leaderboard.to_parquet(leaderboard_file)

    # Get feature importance
    print("\n🔍 Computing feature importance...")
    with stopwatch("Feature importance"):
        # Best effort on purpose: a failure here must not discard the trained predictor.
        try:
            # Compute feature importance with reduced repeats
            feature_importance = predictor.feature_importance(
                test_data,
                num_shuffle_sets=3,
                subsample_size=min(500, len(test_data)),  # Further subsample if needed
            )
            fi_file = results_dir / "feature_importance.parquet"
            print(f"💾 Saving feature importance to {fi_file}")
            feature_importance.to_parquet(fi_file)
        except Exception as e:
            print(f"⚠️ Could not compute feature importance: {e}")

    # Save training settings
    print("\n💾 Saving training settings...")
    combined_settings = AutoGluonTrainingSettings(
        **asdict(settings),
        **asdict(dataset_ensemble),
        classes=training_data.target_labels,
        problem_type=problem_type,
    )
    settings_file = results_dir / "training_settings.toml"
    # TOML documents are UTF-8; do not depend on the platform's default encoding.
    with open(settings_file, "w", encoding="utf-8") as f:
        toml.dump({"settings": asdict(combined_settings)}, f)

    # Save test metrics
    # We need to use pickle here, because the confusion matrix is stored as a dataframe
    # This only matters for classification tasks
    test_metrics_file = results_dir / "test_metrics.pickle"
    print(f"💾 Saving test metrics to {test_metrics_file}")
    with open(test_metrics_file, "wb") as f:
        pickle.dump(test_score, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Save the predictor
    # save() takes no target path: it writes into the directory the predictor was created with,
    # so report that directory instead of a pickle file that is never written.
    print(f"💾 Saving TabularPredictor to {models_dir}")
    predictor.save()

    # Print summary
    print("\n" + "=" * 80)
    print("✅ AutoGluon Training Complete!")
    print("=" * 80)
    print(f"\n📂 Results saved to: {results_dir}")
    print(f"🏆 Best model: {predictor.model_best}")
    print(f"📈 Test {eval_metric}: {test_score[eval_metric]:.4f}")
    print(f"⏱️ Total models trained: {len(leaderboard)}")

    stopwatch.summary()
    print("\nDone! 🎉")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
|
@ -53,14 +53,40 @@ def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
|
|||
use_dummy = collapsed.shape[0] == 0
|
||||
if use_dummy:
|
||||
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
|
||||
pivcols = set(collapsed.index.names) - {"cell_ids"}
|
||||
cols = cast(list[str], list(collapsed.index.names))
|
||||
pivcols = sorted(set(cols) - {"cell_ids"})
|
||||
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
|
||||
collapsed.columns = ["_".join(map(str, v)) for v in collapsed.columns]
|
||||
if use_dummy:
|
||||
collapsed = collapsed.dropna(how="all")
|
||||
expected_cols = _get_expected_collapsed_columns(ds)
|
||||
missing_cols = set(expected_cols) - set(collapsed.columns)
|
||||
if missing_cols:
|
||||
raise ValueError(
|
||||
f"Collapsed dataframe is missing expected columns: {missing_cols=} {collapsed.columns=} {expected_cols=}"
|
||||
)
|
||||
return collapsed
|
||||
|
||||
|
||||
def _get_expected_collapsed_columns(ds: xr.Dataset | xr.DataArray) -> list[str]:
    """List the column names expected after collapsing ``ds`` into a dataframe.

    Each variable name (all data variables of a Dataset, or the name of a DataArray) is
    combined with every combination of coordinate values along all dimensions except
    "cell_ids". Dimensions are visited in sorted order and the parts are joined with "_".

    Args:
        ds: The dataset or data array that is being collapsed.

    Returns:
        The expected column names, grouped by variable.

    """
    dims = sorted(set(ds.dims) - {"cell_ids"})
    # Coordinate values are stringified before joining: `str.join` rejects non-string items
    # (e.g. integer coordinates), and `_collapse_to_dataframe` builds its column names the same way.
    suffixes = ["_".join(map(str, combo)) for combo in product(*(ds.coords[dim].to_numpy() for dim in dims))]

    if isinstance(ds, xr.Dataset):
        names = list(ds.data_vars)
    else:
        assert ds.name is not None, "DataArray must have a name to determine expected columns"
        names = [ds.name]

    return [f"{name}_{suffix}" for name in names for suffix in suffixes]
|
||||
|
||||
|
||||
def _cell_ids_hash(cell_ids: pd.Series) -> str:
    """Return a short hex digest of the cell ids that does not depend on their order.

    The ids are sorted before hashing, so any permutation of the same values yields the
    same 8-byte BLAKE2b digest (16 hex characters).
    """
    hasher = hashlib.blake2b(digest_size=8)
    hasher.update(np.sort(cell_ids.to_numpy()).tobytes())
    return hasher.hexdigest()
|
||||
|
|
@ -130,8 +156,8 @@ class SplittedArrays[ArrayType: (torch.Tensor, np.ndarray, cp.ndarray)]:
|
|||
test: ArrayType
|
||||
|
||||
@cached_property
|
||||
def combined(self) -> ArrayType:
|
||||
"""Combined train and test arrays."""
|
||||
def complete(self) -> ArrayType:
|
||||
"""Complete train and test arrays."""
|
||||
if isinstance(self.train, torch.Tensor) and isinstance(self.test, torch.Tensor):
|
||||
return torch.cat([self.train, self.test], dim=0) # ty:ignore[invalid-return-type]
|
||||
elif isinstance(self.train, cp.ndarray) and isinstance(self.test, cp.ndarray):
|
||||
|
|
@ -272,7 +298,6 @@ class DatasetEnsemble:
|
|||
# ?: We can't use L2SourceDataset as types here because cyclopts can't handle Literals as dict keys
|
||||
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
||||
variable_filters: dict[str, list[str]] = field(default_factory=dict)
|
||||
add_lonlat: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate filters
|
||||
|
|
@ -295,6 +320,10 @@ class DatasetEnsemble:
|
|||
f"Invalid dimension filter for {dim=}: {values}"
|
||||
" Dimension filter values must be a list with one or more entries."
|
||||
)
|
||||
if "Grid" in self.variable_filters.keys():
|
||||
filtered = set(self.variable_filters["Grid"])
|
||||
valid = {"x", "y", "cell_area", "land_area", "water_area", "land_ratio"}
|
||||
assert len(filtered - valid) == 0
|
||||
|
||||
def __hash__(self):
|
||||
return int(self.id(), 16)
|
||||
|
|
@ -320,12 +349,10 @@ class DatasetEnsemble:
|
|||
@cache
|
||||
def read_grid(self) -> gpd.GeoDataFrame:
|
||||
"""Load the grid dataframe and enrich it with lat-lon information."""
|
||||
columns_to_load = ["cell_id", "geometry", "cell_area", "land_area", "water_area", "land_ratio"]
|
||||
# The name add_lonlat has legacy reasons and should be add_location
|
||||
# If add_location is true, keep the x and y
|
||||
# For future reworks: "lat" and "lon" are also available columns
|
||||
if self.add_lonlat:
|
||||
columns_to_load.extend(["x", "y"])
|
||||
if "Grid" in self.variable_filters:
|
||||
columns_to_load = self.variable_filters["Grid"] + ["cell_id", "geometry"]
|
||||
else:
|
||||
columns_to_load = ["cell_id", "geometry", "cell_area", "land_area", "water_area", "land_ratio", "x", "y"]
|
||||
|
||||
# Reading the data takes for the largest grids ~1.7s
|
||||
gridfile = entropice.utils.paths.get_grid_file(self.grid, self.level)
|
||||
|
|
@ -468,7 +495,7 @@ class DatasetEnsemble:
|
|||
ds = unstack_era5_time(ds, era5_agg)
|
||||
|
||||
# Apply the temporal mode
|
||||
if isinstance(self.temporal_mode, int):
|
||||
if isinstance(self.temporal_mode, int) and member != "ArcticDEM":
|
||||
ds = ds.sel(year=self.temporal_mode, drop=True)
|
||||
|
||||
# Actually read data into memory
|
||||
|
|
|
|||
|
|
@ -79,13 +79,16 @@ class RunSettings:
|
|||
"""
|
||||
return "torch" if self.model == "espa" else "cuda"
|
||||
|
||||
def build_pipeline(self, model_hpo_config: ModelHPOConfig) -> Pipeline: # noqa: C901
|
||||
@property
def hpo_config(self) -> ModelHPOConfig:
    """Get the hyperparameter optimization configuration for the selected model and task.

    Thin wrapper around ``get_model_hpo_config(self.model, self.task)``; the lookup is
    repeated on every access because this is a plain property.
    """
    return get_model_hpo_config(self.model, self.task)
|
||||
|
||||
def build_pipeline(self, model_hpo_config: ModelHPOConfig) -> Pipeline:
|
||||
"""Build a scikit-learn Pipeline based on the settings."""
|
||||
# Add a feature scaler / normalization step if specified, but assert that it's only used for non-Tree models
|
||||
if self.model in ["rf", "xgboost"]:
|
||||
assert self.scaler == "none", f"Scaler {self.scaler} is not viable with model {self.model}"
|
||||
elif self.scaler == "none":
|
||||
assert self.scaler != "none", f"No scaler specified for model {self.model}, which is not viable."
|
||||
|
||||
match self.scaler:
|
||||
case "standard":
|
||||
|
|
@ -159,9 +162,9 @@ def _compute_metrics(y: SplittedArrays, y_pred: SplittedArrays, metrics: list[st
|
|||
m = []
|
||||
for metric in metrics:
|
||||
metric_fn = metric_functions[metric]
|
||||
for split in ["train", "test", "combined"]:
|
||||
for split in ["train", "test", "complete"]:
|
||||
value = metric_fn(getattr(y, split), getattr(y_pred, split))
|
||||
m.append({"metric": metric, "split": split, "value": value})
|
||||
m.append({"metric": metric, "split": split, "score": value})
|
||||
return pd.DataFrame(m)
|
||||
|
||||
|
||||
|
|
@ -174,9 +177,9 @@ def _compute_confusion_matrices(
|
|||
{
|
||||
"test": (("true_label", "predicted_label"), confusion_matrix(y.test, y_pred.test, labels=codes)),
|
||||
"train": (("true_label", "predicted_label"), confusion_matrix(y.train, y_pred.train, labels=codes)),
|
||||
"combined": (
|
||||
"complete": (
|
||||
("true_label", "predicted_label"),
|
||||
confusion_matrix(y.combined, y_pred.combined, labels=codes),
|
||||
confusion_matrix(y.complete, y_pred.complete, labels=codes),
|
||||
),
|
||||
},
|
||||
coords={"true_label": labels, "predicted_label": labels},
|
||||
|
|
@ -235,9 +238,9 @@ def _compute_feature_importance(model: Model, best_estimator: Pipeline, training
|
|||
return feature_importances
|
||||
|
||||
|
||||
def _compute_shap_explanation(model: Model, best_estimator: Pipeline, training_data: TrainingSet) -> Explanation:
|
||||
def _compute_shap_explanation(model: Model, best_estimator: Pipeline, training_data: TrainingSet) -> Explanation: # noqa: C901
|
||||
match model:
|
||||
case "espa" | "knn" | "rf": # CUML models do not yet work with TreeExplainer...
|
||||
case "espa" | "knn":
|
||||
train_transformed = training_data.X.as_numpy().train
|
||||
if "scaler" in best_estimator.named_steps:
|
||||
train_transformed = best_estimator.named_steps["scaler"].transform(train_transformed)
|
||||
|
|
@ -263,15 +266,23 @@ def _compute_shap_explanation(model: Model, best_estimator: Pipeline, training_d
|
|||
feature_names=training_data.feature_names,
|
||||
output_names=training_data.target_labels,
|
||||
)
|
||||
case "rf":
|
||||
masker = shap.maskers.Independent(data=training_data.X.as_numpy().train)
|
||||
explainer = TreeExplainer(
|
||||
best_estimator.named_steps["model"].as_sklearn(),
|
||||
data=masker,
|
||||
feature_names=training_data.feature_names,
|
||||
)
|
||||
case "xgboost":
|
||||
explainer = TreeExplainer(best_estimator.named_steps["model"], feature_names=training_data.feature_names)
|
||||
case _:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
|
||||
samples = training_data.X.as_numpy().test
|
||||
if len(samples) > 200:
|
||||
nsamples = 2 * samples.shape[1] + 2048
|
||||
if len(samples) > nsamples:
|
||||
rng = np.random.default_rng(seed=42)
|
||||
sample_indices = rng.choice(len(samples), size=200, replace=False)
|
||||
sample_indices = rng.choice(len(samples), size=nsamples, replace=False)
|
||||
samples = samples[sample_indices]
|
||||
if "scaler" in best_estimator.named_steps:
|
||||
samples = best_estimator.named_steps["scaler"].transform(samples)
|
||||
|
|
@ -314,7 +325,7 @@ def hpsearch_cv(
|
|||
task=settings.task, target=settings.target, device=settings.device
|
||||
)
|
||||
|
||||
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
||||
model_hpo_config = settings.hpo_config
|
||||
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
||||
|
||||
metrics, refit = get_metrics(settings.task)
|
||||
|
|
@ -324,8 +335,8 @@ def hpsearch_cv(
|
|||
print(f"Pipeline steps: {pipeline.named_steps}")
|
||||
|
||||
hp_search = settings.build_search(pipeline, model_hpo_config, metrics, refit)
|
||||
print(f"Starting hyperparameter search with {settings.n_iter} iterations...")
|
||||
with stopwatch(f"RandomizedSearchCV fitting for {settings.n_iter} candidates"):
|
||||
print(f"Starting hyperparameter search for {settings.model} with {settings.n_iter} iterations...")
|
||||
with stopwatch(f"RandomizedSearchCV fitting of {settings.model} for {settings.n_iter} candidates"):
|
||||
fit_params = {f"model__{k}": v for k, v in model_hpo_config.fit_params.items()}
|
||||
hp_search.fit(
|
||||
training_data.X.train,
|
||||
|
|
@ -379,7 +390,6 @@ def hpsearch_cv(
|
|||
),
|
||||
task=settings.task,
|
||||
target=settings.target,
|
||||
training_set=training_data,
|
||||
model=best_estimator,
|
||||
model_type=settings.model,
|
||||
metrics=metrics,
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ def predict_proba(
|
|||
grid_gdf = e.read_grid()
|
||||
for batch in e.create_inference_df(batch_size=batch_size):
|
||||
# Filter rows containing NaN values
|
||||
batch = batch.dropna(axis=0, how="any")
|
||||
batch = batch.dropna(axis="index", how="any")
|
||||
|
||||
# Skip empty batches (all rows had NaN values)
|
||||
if len(batch) == 0:
|
||||
|
|
@ -96,7 +96,13 @@ def predict_proba(
|
|||
|
||||
if isinstance(model, TabularPredictor):
|
||||
print(f"Predicting batch of size {len(batch)} ({type(batch)}) with AutoGluon TabularPredictor...")
|
||||
try:
|
||||
batch_preds = model.predict(batch)
|
||||
except Exception as ex:
|
||||
print("Something went wrong")
|
||||
print(batch)
|
||||
print(batch.columns)
|
||||
raise ex
|
||||
print(f"Batch predictions type: {type(batch_preds)}, shape: {batch_preds.shape}")
|
||||
|
||||
assert isinstance(batch_preds, pd.DataFrame | pd.Series), (
|
||||
|
|
|
|||
|
|
@ -60,6 +60,12 @@ def get_search_space(hp_config: HPConfig) -> dict[str, list | rv_continuous_froz
|
|||
f"Unknown distribution type for {key}: {dist['distribution']}"
|
||||
)
|
||||
distfn = getattr(scipy.stats, dist["distribution"])
|
||||
# Add edge-case for uniform distribution, as low-high is different there
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.uniform.html#scipy.stats.uniform
|
||||
# Using the parameters loc and scale, one obtains the uniform distribution on [loc, loc + scale].
|
||||
if dist["distribution"] == "uniform":
|
||||
search_space[key] = distfn(loc=dist["low"], scale=dist["high"] - dist["low"])
|
||||
else:
|
||||
search_space[key] = distfn(dist["low"], dist["high"])
|
||||
return search_space
|
||||
|
||||
|
|
@ -188,7 +194,7 @@ def get_model_hpo_config(model: str, task: Task, **model_kwargs) -> ModelHPOConf
|
|||
clf = RandomForestClassifier(split_criterion="entropy", **model_kwargs)
|
||||
return ModelHPOConfig(clf, rf_hpconfig)
|
||||
case ("rf", "regressor"):
|
||||
reg = RandomForestRegressor(split_criterion="variance", **model_kwargs)
|
||||
reg = RandomForestRegressor(split_criterion="poisson", **model_kwargs)
|
||||
return ModelHPOConfig(reg, rf_hpconfig)
|
||||
case ("knn", "classifier"):
|
||||
clf = KNeighborsClassifier(**model_kwargs)
|
||||
|
|
|
|||
|
|
@ -1,241 +0,0 @@
|
|||
"""DEPRECATED!!! Training of classification models training."""
|
||||
|
||||
import pickle
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import cyclopts
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import toml
|
||||
import xarray as xr
|
||||
from rich import pretty, traceback
|
||||
from sklearn import set_config
|
||||
from sklearn.metrics import (
|
||||
confusion_matrix,
|
||||
)
|
||||
from sklearn.model_selection import KFold, RandomizedSearchCV
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble, SplittedArrays
|
||||
from entropice.ml.inference import predict_proba
|
||||
from entropice.ml.models import (
|
||||
extract_espa_feature_importance,
|
||||
extract_espa_state,
|
||||
extract_rf_feature_importance,
|
||||
extract_xgboost_feature_importance,
|
||||
get_model_hpo_config,
|
||||
)
|
||||
from entropice.utils.metrics import get_metrics, metric_functions
|
||||
from entropice.utils.paths import get_training_results_dir
|
||||
from entropice.utils.types import Model, TargetDataset, Task
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
|
||||
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
|
||||
|
||||
|
||||
@cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True)
class RunSettings:
    """Cross-validation settings for model training.

    Immutable, keyword-only bundle of the options consumed by ``random_cv``.
    """

    # Number of parameter candidates sampled by the randomized search.
    n_iter: int = 2000
    # Training task; selects the training set, the metrics and the model configuration.
    task: Task = "binary"
    # Target dataset the labels are taken from.
    target: TargetDataset = "darts_v1"
    # Model whose hyperparameters are searched.
    model: Model = "espa"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class TrainingSettings(DatasetEnsemble, RunSettings):
    """Helper wrapper to store combined training and dataset ensemble settings.

    ``random_cv`` fills it from the run settings, the dataset ensemble and the values
    derived during the search, then dumps it to ``search_settings.toml``.
    """

    # Hyperparameter configuration the search space was built from.
    param_grid: dict
    # Number of cross-validation splits used by the search.
    cv_splits: int
    # Names of the metrics that were scored.
    metrics: list[str]
    # Target labels of the training set.
    classes: list[str] | None
|
||||
|
||||
|
||||
@cli.default
def random_cv(
    dataset_ensemble: DatasetEnsemble,
    settings: RunSettings = RunSettings(),
    experiment: str | None = None,
) -> Path:
    """Perform random cross-validation on the training dataset.

    Runs a randomized hyperparameter search with 5-fold cross-validation, then stores the
    search settings, the best estimator, the CV results, per-split metrics, confusion
    matrices (classification tasks only), model-specific state / feature importance and the
    predictions for all cells inside the results directory.

    Args:
        dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
        settings (RunSettings): The cross-validation settings.
        experiment (str | None): Optional experiment name for results directory.

    Returns:
        Path: The directory all results were written to.

    """
    # Array API dispatch is switched off only for xgboost.
    # NOTE(review): the original comment stated that the array API can only be enabled for ESPA,
    # yet this condition also enables it for every other non-xgboost model - confirm the intent.
    use_array_api = settings.model != "xgboost"
    device = "torch" if settings.model == "espa" else "cuda"
    set_config(array_api_dispatch=use_array_api)

    print("Creating training data...")
    training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target, device=device)
    model_hpo_config = get_model_hpo_config(settings.model, settings.task)
    print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
    cv = KFold(n_splits=5, shuffle=True, random_state=42)
    metrics, refit = get_metrics(settings.task)
    search = RandomizedSearchCV(
        model_hpo_config.model,
        model_hpo_config.search_space,
        n_iter=settings.n_iter,
        n_jobs=1,
        cv=cv,
        random_state=42,
        verbose=10,
        scoring=metrics,
        refit=refit,
    )

    print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
    with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
        search.fit(
            training_data.X.train,
            # XGBoost returns its labels as numpy arrays instead of cupy arrays
            # Thus, for the scoring to work, we need to convert them back to numpy
            training_data.y.as_numpy().train if settings.model == "xgboost" else training_data.y.train,
            **model_hpo_config.fit_params,
        )

    print("Best parameters combination found:")
    best_estimator = search.best_estimator_
    best_parameters = best_estimator.get_params()
    for param_name in sorted(model_hpo_config.hp_config.keys()):
        print(f"{param_name}: {best_parameters[param_name]}")

    test_score = search.score(
        training_data.X.test,
        training_data.y.as_numpy().test if settings.model == "xgboost" else training_data.y.test,
    )
    print(
        f"{refit.replace('_', ' ').capitalize()} of the best parameters using the inner CV"
        f" of the random search: {search.best_score_:.3f}"
    )
    print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")

    results_dir = get_training_results_dir(
        experiment=experiment,
        name="random_search",
        grid=dataset_ensemble.grid,
        level=dataset_ensemble.level,
        task=settings.task,
        target=settings.target,
        model_type=settings.model,
    )

    # Store the search settings
    combined_settings = TrainingSettings(
        **asdict(settings),
        **asdict(dataset_ensemble),
        param_grid=model_hpo_config.hp_config,
        cv_splits=cv.get_n_splits(),
        metrics=metrics,
        classes=training_data.target_labels,
    )
    settings_file = results_dir / "search_settings.toml"
    print(f"Storing search settings to {settings_file}")
    # TOML documents are UTF-8; do not depend on the platform's default encoding.
    with open(settings_file, "w", encoding="utf-8") as f:
        toml.dump({"settings": asdict(combined_settings)}, f)

    # Store the best estimator model
    best_model_file = results_dir / "best_estimator_model.pkl"
    print(f"Storing best estimator model to {best_model_file}")
    with open(best_model_file, "wb") as f:
        pickle.dump(best_estimator, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Store the search results
    results = pd.DataFrame(search.cv_results_)
    # Parse the params into individual columns
    params = pd.json_normalize(results["params"])  # ty:ignore[invalid-argument-type]
    # Concatenate the params columns with the original DataFrame
    results = pd.concat([results.drop(columns=["params"]), params], axis=1)
    results_file = results_dir / "search_results.parquet"
    print(f"Storing CV results to {results_file}")
    results.to_parquet(results_file)

    # Compute predictions on all sets and move them to numpy for metric computations
    y_pred = SplittedArrays(
        train=best_estimator.predict(training_data.X.train),
        test=best_estimator.predict(training_data.X.test),
    ).as_numpy()

    # Compute and store metrics, one table per split ("<split>_metrics")
    y = training_data.y.as_numpy()
    all_metrics = {
        f"{split}_metrics": {
            metric: metric_functions[metric](getattr(y, split), getattr(y_pred, split)) for metric in metrics
        }
        for split in ("test", "train", "combined")
    }
    test_metrics_file = results_dir / "metrics.toml"
    print(f"Storing test metrics to {test_metrics_file}")
    with open(test_metrics_file, "w", encoding="utf-8") as f:
        toml.dump(all_metrics, f)

    # Make confusion matrices for classification tasks
    if settings.task in ["binary", "count_regimes", "density_regimes"]:
        codes = np.array(training_data.target_codes)
        cm = xr.Dataset(
            {
                "test": (("true_label", "predicted_label"), confusion_matrix(y.test, y_pred.test, labels=codes)),
                "train": (("true_label", "predicted_label"), confusion_matrix(y.train, y_pred.train, labels=codes)),
                "combined": (
                    ("true_label", "predicted_label"),
                    confusion_matrix(y.combined, y_pred.combined, labels=codes),
                ),
            },
            coords={"true_label": training_data.target_labels, "predicted_label": training_data.target_labels},
        )
        # Store the confusion matrices
        cm_file = results_dir / "confusion_matrix.nc"
        print(f"Storing confusion matrices to {cm_file}")
        cm.to_netcdf(cm_file, engine="h5netcdf")

    # Get the inner state of the best estimator (only available for ESPA)
    if settings.model == "espa":
        state = extract_espa_state(best_estimator, training_data)
        state_file = results_dir / "best_estimator_state.nc"
        print(f"Storing best estimator state to {state_file}")
        state.to_netcdf(state_file, engine="h5netcdf")

    # Feature importance is exported the same way for every model that has an extractor;
    # models without one (e.g. knn) are skipped.
    fi_extractors = {
        "espa": extract_espa_feature_importance,
        "xgboost": extract_xgboost_feature_importance,
        "rf": extract_rf_feature_importance,
    }
    extract_fi = fi_extractors.get(settings.model)
    if extract_fi is not None:
        fi = extract_fi(best_estimator, training_data)
        fi_file = results_dir / "best_estimator_feature_importance.parquet"
        print(f"Storing best estimator feature importance to {fi_file}")
        fi.to_parquet(fi_file)

    # Predict probabilities for all cells
    print("Predicting probabilities for all cells...")
    preds = predict_proba(dataset_ensemble, model=best_estimator, task=settings.task, device=device)
    print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
    preds_file = results_dir / "predicted_probabilities.parquet"
    print(f"Storing predicted probabilities to {preds_file}")
    preds.to_parquet(preds_file)

    stopwatch.summary()
    print("Done.")
    return results_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
"""Training utilities for Entropice."""
|
||||
|
||||
import pickle
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import cached_property
|
||||
|
|
@ -47,7 +49,7 @@ def move_data_to_device(data: ndarray, device: Literal["torch", "cuda", "cpu"])
|
|||
|
||||
|
||||
@dataclass
|
||||
class HPOCV:
|
||||
class HPOCV: # noqa: D101
|
||||
method: HPSearch
|
||||
splitter: Splitter
|
||||
scaler: Scaler
|
||||
|
|
@ -55,13 +57,13 @@ class HPOCV:
|
|||
n_iter: int
|
||||
hpconfig: HPConfig
|
||||
|
||||
@property
|
||||
def search_space(self):
|
||||
@cached_property
|
||||
def search_space(self): # noqa: D102
|
||||
return get_search_space(self.hpconfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoML:
|
||||
class AutoML: # noqa: D101
|
||||
time_budget: int
|
||||
preset: str
|
||||
hpo: bool
|
||||
|
|
@ -80,7 +82,7 @@ class Training:
|
|||
|
||||
f(dataset, method) -> (model, metrics)
|
||||
|
||||
Metrics refer to a simple dataframe in long format, with the columns: "metric", "split", "value".
|
||||
Metrics refer to a simple dataframe in long format, with the columns: "metric", "split", "score".
|
||||
Split is either "train", "test" or "complete".
|
||||
"""
|
||||
|
||||
|
|
@ -89,12 +91,11 @@ class Training:
|
|||
method: HPOCV | AutoML
|
||||
task: Task
|
||||
target: TargetDataset
|
||||
training_set: TrainingSet # TODO: Store Training Set to improve loading time (?)
|
||||
model: Any
|
||||
model_type: Model
|
||||
metrics: pd.DataFrame
|
||||
feature_importance: pd.DataFrame
|
||||
shap_explanation: Explanation
|
||||
shap_explanation: Explanation | None
|
||||
predictions: gpd.GeoDataFrame
|
||||
confusion_matrix: xr.Dataset | None # only for classification tasks
|
||||
cv_results: pd.DataFrame | None # only for HPOCV
|
||||
|
|
@ -115,6 +116,31 @@ class Training:
|
|||
"""Get the list of metric names from the metrics DataFrame."""
|
||||
return self.metrics["metric"].unique().tolist()
|
||||
|
||||
@cached_property
def training_set(self) -> TrainingSet:
    """Build the training set for this run on first access and cache the result.

    The set is created from ``self.dataset`` for this run's task and target,
    always requesting the CPU device.
    """
    build = self.dataset.create_training_set
    return build(self.task, self.target, device="cpu")
|
||||
|
||||
@property
def method_type(self) -> Literal["HPOCV", "AutoML"]:
    """Name the kind of method configured for this training run.

    Returns:
        "HPOCV" when ``self.method`` is an ``HPOCV`` instance, "AutoML" when it
        is an ``AutoML`` instance.

    Raises:
        ValueError: If ``self.method`` is neither of the two.
    """
    configured = self.method
    if isinstance(configured, HPOCV):
        return "HPOCV"
    if isinstance(configured, AutoML):
        return "AutoML"
    raise ValueError(f"Unknown method type: {type(self.method)}")
|
||||
|
||||
@property
|
||||
def n_trials(self) -> int | None:
|
||||
"""Get the number of trials in the hyperparameter search, if applicable."""
|
||||
if self.method_type == "HPOCV" and self.cv_results is not None:
|
||||
return len(self.cv_results)
|
||||
elif self.method_type == "AutoML" and self.leaderboard is not None:
|
||||
return len(self.leaderboard)
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def get_state(self) -> xr.Dataset | pd.DataFrame | None:
|
||||
"""Get the inner state of the trained model, if available."""
|
||||
|
|
@ -123,6 +149,10 @@ class Training:
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_metrics_from_split(self, split: Literal["train", "test", "complete"]) -> dict[str, float]:
    """Map each metric name to its score for one data split.

    Args:
        split: Which split's rows to select from ``self.metrics``.

    Returns:
        A dictionary keyed by the "metric" column with values taken from the
        "score" column, restricted to rows whose "split" column equals ``split``.
    """
    in_split = self.metrics["split"] == split
    scores = self.metrics[in_split].set_index("metric")["score"]
    return scores.to_dict()  # ty:ignore[invalid-return-type]
|
||||
|
||||
def save(self):
|
||||
"""Save the training results to the specified path."""
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -130,7 +160,6 @@ class Training:
|
|||
model_file = self.path / "model.pkl"
|
||||
metrics_file = self.path / "metrics.parquet"
|
||||
feature_importance_file = self.path / "feature_importance.parquet"
|
||||
explanations_file = self.path / "shap_explanation.pkl"
|
||||
predictions_file = self.path / "predictions.parquet"
|
||||
# Save config
|
||||
with open(config_file, "w") as f:
|
||||
|
|
@ -150,9 +179,13 @@ class Training:
|
|||
model_file.write_bytes(pickle.dumps(self.model))
|
||||
self.metrics.to_parquet(metrics_file)
|
||||
self.feature_importance.to_parquet(feature_importance_file)
|
||||
explanations_file.write_bytes(pickle.dumps(self.shap_explanation))
|
||||
self.predictions.to_parquet(predictions_file)
|
||||
|
||||
# Save SHAP explanation if it exists
|
||||
if self.shap_explanation is not None:
|
||||
explanations_file = self.path / "shap_explanation.pkl"
|
||||
explanations_file.write_bytes(pickle.dumps(self.shap_explanation))
|
||||
|
||||
# Save the confusion matrix if it exists
|
||||
if self.confusion_matrix is not None:
|
||||
cm_file = self.path / "confusion_matrix.nc"
|
||||
|
|
@ -168,13 +201,14 @@ class Training:
|
|||
self.leaderboard.to_parquet(leaderboard_file)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path, device: Literal["cpu", "cuda"] = "cpu") -> "Training":
|
||||
def load(cls, path: Path) -> "Training":
|
||||
"""Load a training run from the specified path."""
|
||||
config_file = path / "training_config.toml"
|
||||
model_file = path / "model.pkl"
|
||||
metrics_file = path / "metrics.parquet"
|
||||
feature_importance_file = path / "feature_importance.parquet"
|
||||
predictions_file = path / "predictions.parquet"
|
||||
explanations_file = path / "shap_explanation.pkl"
|
||||
cm_file = path / "confusion_matrix.nc"
|
||||
cv_results_file = path / "search_results.parquet"
|
||||
leaderboard_file = path / "leaderboard.parquet"
|
||||
|
|
@ -188,7 +222,6 @@ class Training:
|
|||
model_type = config["model_type"]
|
||||
|
||||
dataset = DatasetEnsemble(**config["dataset"])
|
||||
training_set = dataset.create_training_set(task, target, device)
|
||||
|
||||
method_type = config["method_type"]
|
||||
if method_type == "HPOCV":
|
||||
|
|
@ -202,9 +235,13 @@ class Training:
|
|||
model = pickle.loads(model_file.read_bytes())
|
||||
metrics = pd.read_parquet(metrics_file)
|
||||
feature_importance = pd.read_parquet(feature_importance_file)
|
||||
shap_explanation = pickle.loads((path / "shap_explanation.pkl").read_bytes())
|
||||
predictions = gpd.read_parquet(predictions_file)
|
||||
|
||||
# Load SHAP explanation if it exists
|
||||
shap_explanation = None
|
||||
if explanations_file.exists():
|
||||
shap_explanation = pickle.loads(explanations_file.read_bytes())
|
||||
|
||||
# Load confusion matrix if it exists
|
||||
confusion_matrix = None
|
||||
if cm_file.exists():
|
||||
|
|
@ -225,7 +262,6 @@ class Training:
|
|||
method=method,
|
||||
task=task,
|
||||
target=target,
|
||||
training_set=training_set,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
metrics=metrics,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ def sample_ensemble() -> Generator[DatasetEnsemble]:
|
|||
grid="hex",
|
||||
level=3, # Use level 3 for much faster tests
|
||||
members=["AlphaEarth"], # Use only one member for faster tests
|
||||
add_lonlat=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -29,7 +28,6 @@ def sample_ensemble_v2() -> Generator[DatasetEnsemble]:
|
|||
grid="hex",
|
||||
level=3, # Use level 3 for much faster tests
|
||||
members=["AlphaEarth"], # Use only one member for faster tests
|
||||
add_lonlat=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -41,7 +39,6 @@ class TestDatasetEnsemble:
|
|||
assert sample_ensemble.grid == "hex"
|
||||
assert sample_ensemble.level == 3
|
||||
assert "AlphaEarth" in sample_ensemble.members
|
||||
assert sample_ensemble.add_lonlat is True
|
||||
|
||||
def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||
"""Test that get_targets() returns a GeoDataFrame."""
|
||||
|
|
@ -106,8 +103,7 @@ class TestDatasetEnsemble:
|
|||
# Should NOT have geometry column
|
||||
assert "geometry" not in features.columns
|
||||
|
||||
# Should have location columns if add_lonlat is True
|
||||
if sample_ensemble.add_lonlat:
|
||||
# Should have location columns
|
||||
assert "x" in features.columns
|
||||
assert "y" in features.columns
|
||||
|
||||
|
|
|
|||
|
|
@ -1,222 +0,0 @@
|
|||
"""Tests for training.py module, specifically random_cv function.
|
||||
|
||||
This test suite validates the random_cv training function across all model-task
|
||||
combinations using a minimal hex level 3 grid with synopsis temporal mode.
|
||||
|
||||
Test Coverage:
|
||||
- All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn
|
||||
- Device handling for each model type (torch/CUDA/cuML compatibility)
|
||||
- Multi-label target dataset support
|
||||
- Temporal mode configuration (synopsis)
|
||||
- Output file creation and validation
|
||||
|
||||
Running Tests:
|
||||
# Run all training tests (18 tests total, ~3 iterations each)
|
||||
pixi run pytest tests/test_training.py -v
|
||||
|
||||
# Run only device handling tests
|
||||
pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v
|
||||
|
||||
# Run a specific model-task combination
|
||||
pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v
|
||||
|
||||
Note: Tests use minimal iterations (3) and level 3 grid for speed.
|
||||
Full production runs use higher iteration counts (100-2000).
|
||||
"""
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.randomsearch import RunSettings, random_cv
|
||||
from entropice.utils.types import Model, Task
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def test_ensemble():
    """Provide one small DatasetEnsemble shared by every test in this module.

    The ensemble is configured with a level 3 hex grid and the synopsis
    temporal mode so that tests run quickly.
    """
    ensemble_kwargs = {
        "grid": "hex",
        "level": 3,
        "temporal_mode": "synopsis",
        "members": ["AlphaEarth"],  # Use only one member for faster tests
        "add_lonlat": True,
    }
    return DatasetEnsemble(**ensemble_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def cleanup_results():
    """Yield a registration callback and delete the registered directories on teardown.

    Tests pass every results directory they create to the yielded callback.
    After the test finishes, each registered directory that still exists is
    removed recursively.
    """
    pending = []

    def track(results_dir):
        """Remember ``results_dir`` for removal after the test and hand it back."""
        pending.append(results_dir)
        return results_dir

    yield track

    # Teardown: touch only what this test registered, and only if it is still there.
    for path in pending:
        if not path.exists():
            continue
        shutil.rmtree(path)
|
||||
|
||||
|
||||
# Models and tasks used to parametrize the tests below.
# Note: Not all combinations make sense, but we test all of them to ensure robustness.
MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"]
TASKS: list[Task] = ["binary", "count", "density"]
|
||||
|
||||
|
||||
class TestRandomCV:
    """Tests that exercise the random_cv function end to end."""

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("task", TASKS)
    def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results):
        """Run random_cv for every model-task pair and check its outputs.

        Each combination runs with 3 iterations and must:
        - complete without raising
        - create the results directory
        - write all expected output files
        """
        # darts_v1 is the target used for every combination
        settings = RunSettings(n_iter=3, task=task, target="darts_v1", model=model)

        results_dir = random_cv(dataset_ensemble=test_ensemble, settings=settings, experiment="test_training")
        cleanup_results(results_dir)

        assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"

        # Files written regardless of model or task
        expected_files = [
            "search_settings.toml",
            "best_estimator_model.pkl",
            "search_results.parquet",
            "metrics.toml",
            "predicted_probabilities.parquet",
        ]

        # A confusion matrix is expected for the binary task and for any task
        # whose name carries the "_regimes" suffix.
        known_task = task in ["binary", "count", "density"]
        if known_task and (task == "binary" or "_regimes" in task):
            expected_files.append("confusion_matrix.nc")

        # These model types additionally store their estimator state
        if model in ["espa", "xgboost", "rf"]:
            expected_files.append("best_estimator_state.nc")

        for filename in expected_files:
            filepath = results_dir / filename
            assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"

    @pytest.mark.parametrize("model", MODELS)
    def test_device_handling(self, test_ensemble, model: Model, cleanup_results):
        """Check that random_cv raises no device-related RuntimeError for any model.

        A RuntimeError whose message mentions a device keyword fails the test
        with an explicit message; any other RuntimeError propagates unchanged.
        """
        # The binary task keeps the device check simple
        settings = RunSettings(n_iter=3, task="binary", target="darts_v1", model=model)

        try:
            results_dir = random_cv(dataset_ensemble=test_ensemble, settings=settings, experiment="test_training")
            cleanup_results(results_dir)
        except RuntimeError as e:
            error_msg = str(e).lower()
            device_keywords = ["cuda", "gpu", "device", "cpu", "torch", "cupy"]
            mentions_device = any(keyword in error_msg for keyword in device_keywords)
            if not mentions_device:
                # Not a device problem: let the original error surface
                raise
            pytest.fail(f"Device handling error for {model=}: {e}")

    def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
        """Run random_cv against the darts_mllabels target and check that results are stored."""
        settings = RunSettings(n_iter=3, task="binary", target="darts_mllabels", model="espa")

        results_dir = random_cv(dataset_ensemble=test_ensemble, settings=settings, experiment="test_training")
        cleanup_results(results_dir)

        assert results_dir.exists(), "Results directory not created"
        settings_file = results_dir / "search_settings.toml"
        assert settings_file.exists()

    def test_temporal_mode_synopsis(self, cleanup_results):
        """Check that temporal_mode='synopsis' ends up in the stored search settings."""
        import toml

        ensemble = DatasetEnsemble(
            grid="hex",
            level=3,
            temporal_mode="synopsis",
            members=["AlphaEarth"],
            add_lonlat=True,
        )
        settings = RunSettings(n_iter=3, task="binary", target="darts_v1", model="espa")

        # The run is expected to use the ensemble's synopsis temporal mode
        results_dir = random_cv(dataset_ensemble=ensemble, settings=settings, experiment="test_training")
        cleanup_results(results_dir)

        assert results_dir.exists(), "Results directory not created"

        # Read back what the run persisted and compare the temporal mode
        with open(results_dir / "search_settings.toml") as f:
            stored_settings = toml.load(f)

        assert stored_settings["settings"]["temporal_mode"] == "synopsis"
|
||||
Loading…
Add table
Add a link
Reference in a new issue