Uw #12 (Open)

wants to merge 8 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
**/__pycache__
outputs
test_data
logs
multirun
lightning_logs
152 changes: 152 additions & 0 deletions .hydra/config.yaml
@@ -0,0 +1,152 @@
xpd: jobs/new_baseline
xp_name: ${hydra:runtime.choices.xp}
xp_overrides:
model:
save_dir: ${hydra:runtime.output_dir}
rec_weight: ${rec_weight}
test_metrics: ${ose_metrics.test_metrics}
pre_metric_fn: ${ose_metrics.pre_metric_fn}
datamodule:
input_da:
_target_: ${oc.select:data.${inp_da},contrib.ose2osse.data.load_ose_data}
domains:
test:
time:
_target_: builtins.slice
_args_:
- '2016-12-01'
- '2018-01-31'
train:
time:
_target_: builtins.slice
_args_:
- '2016-12-01'
- '2018-01-31'
xrds_kw:
strides:
lat: 100
lon: 100
domain_limits: ${domain.train}
dl_kw:
batch_size: 16
num_workers: 10
persistent_workers: false
patcher_cls:
_partial_: true
dense_vars: null
sparse_vars:
- input
- tgt
res: 5000.0
weight: ${rec_weight}
cache: false
nproc_rec: 10
overrides_targets:
model.rec_weight: contrib.multi_domain_diag.get_smooth_spat_rec_weight
model.rec_weight.orig_rec_weight: contrib.multi_domain_diag.load_cfg_from_xp
datamodule.patcher_cls: contrib.ortho.OrthoPatcher
datamodule.patcher_cls.weight: contrib.multi_domain_diag.get_smooth_spat_rec_weight
datamodule.patcher_cls.weight.orig_rec_weight: contrib.multi_domain_diag.load_cfg_from_xp
trainer:
_target_: pytorch_lightning.Trainer
inference_mode: false
accelerator: gpu
devices: 2
logger:
_target_: pytorch_lightning.loggers.CSVLogger
save_dir: ${xpd}
name: ${xp_name}
lit_mod:
_target_: contrib.multi_domain_diag.load_cfg_from_xp
key: model
xpd: ${xpd}
overrides: ${xp_overrides}
overrides_targets: ${overrides_targets}
dm:
_target_: contrib.multi_domain_diag.load_cfg_from_xp
key: datamodule
xpd: ${xpd}
overrides: ${xp_overrides}
overrides_targets: ${overrides_targets}
ckpt:
_target_: src.utils.best_ckpt
xp_dir: ${xpd}
rec_weight:
orig_rec_weight:
key: model.rec_weight
xpd: ${xpd}
entrypoints:
- _target_: pytorch_lightning.seed_everything
seed: 333
- _target_: builtins.print
_args_:
- ${hydra:runtime.output_dir}
- _target_: pytorch_lightning.Trainer.test
self: ${trainer}
model: ${lit_mod}
datamodule: ${dm}
ckpt_path: ${ckpt}
domain:
train:
lat:
_target_: builtins.slice
_args_:
- 32
- 54
lon:
_target_: builtins.slice
_args_:
- -51
- -9
test:
lat:
_target_: builtins.slice
_args_:
- 33
- 53
lon:
_target_: builtins.slice
_args_:
- -50
- -10
inp_da: default
data:
sst: contrib.multimodal.ose_utils.load_ose_data_with_mursst
ose_metrics:
test_metrics:
mu:
_target_: contrib.ose2osse.diagnostics.rmse_score
_partial_: true
rmse:
_target_: contrib.ose2osse.diagnostics.rmse
_partial_: true
lx:
_target_: contrib.ose2osse.diagnostics.dc_spat_res_from_diag_data
_partial_: true
v: rec
pre_metric_fn:
_target_: src.utils.pipe
_partial_: true
fns:
- _target_: operator.attrgetter
_args_:
- out
- _target_: contrib.ose2osse.diagnostics.compute_segment_data
_partial_: true
oi:
_target_: src.utils.pipe
inp:
_target_: xarray.open_dataset
_args_:
- ../sla-data-registry/data_OSE/NATL/training/ssh_alg_h2g_j2g_j2n_j3_s3a_duacs.nc
fns:
- _target_: operator.attrgetter
_args_:
- ssh
test_track:
_target_: xarray.open_dataset
_args_:
- ../sla-data-registry/data_OSE/along_track/c2_2017_world.nc
- _target_: operator.itemgetter
_args_:
- 1
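Throughout the config above, `_target_` keys name a dotted import path that Hydra instantiates with the node's `_args_` and remaining keys, which is how plain YAML like the `time` entries yields live `slice` objects. The sketch below mimics that mechanism with the standard library only; it is illustrative, not the real `hydra.utils.instantiate` (which also handles `_partial_`, nested configs, and interpolation).

```python
import importlib

def instantiate(node):
    # Minimal sketch of Hydra-style ``_target_`` instantiation (illustrative
    # only; the real entry point is hydra.utils.instantiate).
    if isinstance(node, dict) and "_target_" in node:
        # Split "builtins.slice" into module ("builtins") and attribute ("slice").
        module_path, _, attr = node["_target_"].rpartition(".")
        target = getattr(importlib.import_module(module_path), attr)
        # Positional args come from _args_; everything else becomes a kwarg.
        args = [instantiate(a) for a in node.get("_args_", [])]
        kwargs = {k: instantiate(v) for k, v in node.items()
                  if k not in ("_target_", "_args_", "_partial_")}
        return target(*args, **kwargs)
    return node

# The test-domain time window from the config above:
time_slice = instantiate({"_target_": "builtins.slice",
                          "_args_": ["2016-12-01", "2018-01-31"]})
print(time_slice)  # slice('2016-12-01', '2018-01-31', None)
```

The same pattern explains entries such as `_target_: xarray.open_dataset` in the `ose_metrics` block: the YAML only records what to import and how to call it.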
171 changes: 171 additions & 0 deletions .hydra/hydra.yaml
@@ -0,0 +1,171 @@
hydra:
run:
dir: /raid/localscratch/qfebvre/4dvarnet-starter
sweep:
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
subdir: ${hydra.job.num}
launcher:
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
sweeper:
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
max_batch_size: null
params: null
help:
app_name: ${hydra.job.name}
header: '${hydra.help.app_name} is powered by Hydra.

'
footer: 'Powered by Hydra (https://hydra.cc)

Use --hydra-help to view Hydra specific help

'
template: '${hydra.help.header}

== Configuration groups ==

Compose your configuration from those groups (group=option)


$APP_CONFIG_GROUPS


== Config ==

Override anything in the config (foo.bar=value)


$CONFIG


${hydra.help.footer}

'
hydra_help:
template: 'Hydra (${hydra.runtime.version})

See https://hydra.cc for more info.


== Flags ==

$FLAGS_HELP


== Configuration groups ==

Compose your configuration from those groups (For example, append hydra/job_logging=disabled
to command line)


$HYDRA_CONFIG_GROUPS


Use ''--cfg hydra'' to Show the Hydra config.

'
hydra_help: ???
hydra_logging:
version: 1
formatters:
simple:
format: '[%(asctime)s][HYDRA] %(message)s'
handlers:
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
root:
level: INFO
handlers:
- console
loggers:
logging_example:
level: DEBUG
disable_existing_loggers: false
job_logging:
version: 1
formatters:
simple:
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
handlers:
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
file:
class: logging.FileHandler
formatter: simple
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
root:
level: INFO
handlers:
- console
- file
disable_existing_loggers: false
env: {}
mode: RUN
searchpath:
- pkg://config
- pkg://contrib
callbacks: {}
output_subdir: .hydra
overrides:
hydra:
- hydra.run.dir="/raid/localscratch/qfebvre/4dvarnet-starter"
- hydra.job.name=train_ddp_process_1
- hydra.mode=RUN
task:
- xp=test_base
- xpd=jobs/new_baseline
- +params=[ose_metrics,ortho_test_ose]
- trainer.devices=2
- domain=cNATL
job:
name: train_ddp_process_1
chdir: null
override_dirname: +params=[ose_metrics,ortho_test_ose],domain=cNATL,trainer.devices=2,xp=test_base,xpd=jobs/new_baseline
id: ???
num: ???
config_name: main
env_set: {}
env_copy: []
config:
override_dirname:
kv_sep: '='
item_sep: ','
exclude_keys: []
runtime:
version: 1.3.2
version_base: '1.3'
cwd: /raid/localscratch/qfebvre/4dvarnet-starter
config_sources:
- path: hydra.conf
schema: pkg
provider: hydra
- path: /raid/localscratch/qfebvre/4dvarnet-starter/config
schema: file
provider: main
- path: config
schema: pkg
provider: hydra.searchpath in main
- path: contrib
schema: pkg
provider: hydra.searchpath in main
- path: ''
schema: structured
provider: schema
output_dir: /raid/localscratch/qfebvre/4dvarnet-starter
choices:
domain: cNATL
xp: test_base
hydra/env: default
hydra/callbacks: null
hydra/job_logging: default
hydra/hydra_logging: default
hydra/hydra_help: default
hydra/help: default
hydra/sweeper: basic
hydra/launcher: basic
hydra/output: default
verbose: false
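The `job.override_dirname` value above is derived from the task overrides using the `kv_sep`/`item_sep` settings in `job.config.override_dirname`. A small sketch of the assumed assembly rule (sort the overrides, join key/value pairs; `exclude_keys` handling is omitted since the config leaves it empty) reproduces the recorded string:

```python
def override_dirname(task_overrides, kv_sep="=", item_sep=","):
    # Sketch of how Hydra appears to build job.override_dirname: sort the
    # task overrides lexically, then join them with item_sep, re-joining each
    # key and value with kv_sep. Assumed behavior, not Hydra's actual code.
    parts = []
    for override in sorted(task_overrides):
        key, _, value = override.partition("=")
        parts.append(f"{key}{kv_sep}{value}")
    return item_sep.join(parts)

# The task overrides recorded in this run:
overrides = [
    "xp=test_base",
    "xpd=jobs/new_baseline",
    "+params=[ose_metrics,ortho_test_ose]",
    "trainer.devices=2",
    "domain=cNATL",
]
print(override_dirname(overrides))
```

Note that `+params=[...]` sorts first because `+` precedes letters, matching the `override_dirname` shown in the job section.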
5 changes: 5 additions & 0 deletions .hydra/overrides.yaml
@@ -0,0 +1,5 @@
- xp=test_base
- xpd=jobs/new_baseline
- +params=[ose_metrics,ortho_test_ose]
- trainer.devices=2
- domain=cNATL
11 changes: 10 additions & 1 deletion config/__init__.py
@@ -4,7 +4,7 @@
OmegaConf.register_new_resolver(
"_singleton",
lambda k: dict(
-        _target_="main.SingletonStore.get",
+        _target_="config.SingletonStore.get",
key=k,
obj_cfg="${" + k + "}",
),
@@ -15,6 +15,15 @@
"singleton", lambda k: "${oc.create:${_singleton:" + k + "}}", replace=True
)

def drop_target(cfg):
    # OmegaConf.resolve resolves interpolations in place and returns None,
    # so the result must not be re-assigned to cfg.
    OmegaConf.resolve(cfg)
    if "_target_" in cfg:
        del cfg["_target_"]
    return cfg

OmegaConf.register_new_resolver(
"drop_tgt", drop_target, replace=True
)

class SingletonStore:
STORE = dict()
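The `_singleton` resolver above wires `${singleton:...}` interpolations to `config.SingletonStore.get`, so one configured object (e.g. a shared `rec_weight`) is built once and reused across the config. The class body is truncated in this diff; the sketch below shows the assumed caching semantics with plain Python callables standing in for OmegaConf nodes:

```python
class SingletonStore:
    # Sketch of the caching pattern suggested by config/__init__.py
    # (assumed semantics; the real get() receives an OmegaConf node
    # for obj_cfg rather than a callable).
    STORE = dict()

    @classmethod
    def get(cls, key, obj_cfg):
        if key not in cls.STORE:
            # First request for this key: construct and cache the object.
            cls.STORE[key] = obj_cfg() if callable(obj_cfg) else obj_cfg
        return cls.STORE[key]

first = SingletonStore.get("rec_weight", lambda: object())
second = SingletonStore.get("rec_weight", lambda: object())
print(first is second)  # True
```

The PR's retargeting of `_target_` from `main.SingletonStore.get` to `config.SingletonStore.get` matters precisely because the cache lives on the class: both references must resolve to the same class object for the store to be shared.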