Skip to content

Commit

Permalink
Hybrid flavor 3, parallel implementation
Browse files Browse the repository at this point in the history
The parallel version of the hybrid scheme is finalized.
Both serial and parallel runs have been matched.
- A bug in flavor 3 algorithm has been fixed. The bug
didn't make use of the right weights when updating the
state variables
- The hybrid weights at the obs locations is now implemeneted
for identity obs
- Functionality to allow the weight sd to change in time has been
added to the hybrid module
  • Loading branch information
mgharamti committed Mar 6, 2023
1 parent 82e510c commit 60fae50
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 57 deletions.
84 changes: 49 additions & 35 deletions assimilation_code/modules/assimilation/assim_tools_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ module assim_tools_mod
get_close_obs, get_close_state, &
convert_vertical_obs, convert_vertical_state

use distributed_state_mod, only : create_mean_window, free_mean_window
use distributed_state_mod, only : create_mean_window, free_mean_window, get_state

use quality_control_mod, only : good_dart_qc, DARTQC_FAILED_VERT_CONVERT

Expand Down Expand Up @@ -414,11 +414,10 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
real(r8) :: stat_obs_prior_mean, stat_obs_prior_var
real(r8) :: hyb_obs_var, corr_rho
real(r8) :: vary_ss_hybrid_mean, vary_ss_hybrid_sd
real(r8), allocatable :: orig_hybrid_weight(:)
real(r8), allocatable :: stat_obs_prior(:)
real(r8), allocatable :: dtrd(:), tr_R(:), tr_B(:)
real(r8) :: sum_d, sum_R, sum_B, sum_d_all, sum_R_all, sum_B_all
real(r8) :: fs, ens_mean(1), orig_hyb_mean
real(r8) :: fs, ens_mean(1), orig_hyb_mean(1)
integer :: i_qc

! Are we hybridizing the increments?
Expand Down Expand Up @@ -568,9 +567,7 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,

! Use the original weight for hybridizing
! Need to change for spatially-varying form
allocate(orig_hybrid_weight(ens_handle%num_vars))
orig_hybrid_weight = ens_handle%copies(HYB_MEAN_COPY, :)
!ens_handle%copies(ENS_MEAN_COPY, :) = ens_handle%copies(HYB_MEAN_COPY, :)
ens_handle%copies(ENS_MEAN_COPY, :) = ens_handle%copies(HYB_MEAN_COPY, :)
endif

! Get info on my number and indices for obs
Expand Down Expand Up @@ -682,7 +679,7 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
! rescaled static covariance more appropriately estimated
! the total forecast-error variance
if (do_hybrid) then

if(hybrid_scaling < 0.0_r8) then

i_qc = 0
Expand Down Expand Up @@ -747,6 +744,10 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
stat_obs_ens_handle%copies(stat_obs_var_copy, :)
endif

!print *, ''
!write(*, '(A)') 'Before Obs Loop:'
!write(*, '(A, 20F10.6)') 'x(879): ', ens_handle%copies(1:ens_size, 879)

! use MLOOP for the overall outer loop times; LG_GRN is for
! sections inside the overall loop, including the total time
! for the state_update and obs_update loops. use SM_GRN for
Expand Down Expand Up @@ -788,17 +789,6 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
call get_state_meta_data(-1 * int(base_obs_type,i8), dummyloc, base_obs_kind) ! identity obs
endif

! Weight coefficient at this obs location
if (do_hybrid) then
! I don't know the exact obs location, take average in space
if (base_obs_type > 0 ) then
orig_hyb_mean = sum(orig_hybrid_weight)/size(orig_hybrid_weight)
else
! Identity case:
orig_hyb_mean = orig_hybrid_weight(-1 * int(base_obs_type,i8))
endif
endif

! Get the value of the observation
call get_obs_values(observation, obs, obs_val_index)

Expand Down Expand Up @@ -845,13 +835,20 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
! The assumption here is that if the ensemble QC is OK, then the
! static one is also OK. May need a revisit!
stat_obs_prior = stat_obs_ens_handle%copies(1:hybrid_ens_size, owners_index)


! I don't know the exact obs location, take average in space
if (base_obs_type > 0 ) then
orig_hyb_mean = sum(ens_handle%copies(ENS_MEAN_COPY, :)) / size(ens_handle%copies(ENS_MEAN_COPY, :))
else ! Identity case:
orig_hyb_mean = stat_obs_ens_handle%copies(hybrid_ens_size+2, owners_index)
endif
print *, 'orig_hyb_mean: ', orig_hyb_mean

! Find the static variance for this observation
hyb_obs_var = stat_obs_ens_handle%copies(stat_obs_var_copy, owners_index)

call obs_increment(obs_prior, ens_size, obs(1), &
obs_err_var, obs_inc, inflate, my_inflate, &
my_inflate_sd, net_a(1), orig_hyb_mean, stat_obs_prior)
my_inflate_sd, net_a(1), orig_hyb_mean(1), stat_obs_prior)
endif
end do

Expand Down Expand Up @@ -940,7 +937,7 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
whichvert_real = real(whichvert_obs_in_localization_coord, r8)
if(local_varying_ss_inflate) then
call broadcast_send(map_pe_to_task(ens_handle, owner), obs_prior, obs_inc, &
orig_obs_prior_mean, orig_obs_prior_var, net_a, scalar1=obs_qc, &
orig_obs_prior_mean, orig_obs_prior_var, net_a, stat_obs_prior, orig_hyb_mean, scalar1=obs_qc, &
scalar2=vertvalue_obs_in_localization_coord, scalar3=whichvert_real, &
scalar4=my_hybrid_weight, scalar5=my_hybrid_weight_sd, scalar6=hyb_obs_var)

Expand All @@ -965,7 +962,7 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
!>the cost of sending unneeded values
if(local_varying_ss_inflate) then
call broadcast_recv(map_pe_to_task(ens_handle, owner), obs_prior, obs_inc, &
orig_obs_prior_mean, orig_obs_prior_var, net_a, scalar1=obs_qc, &
orig_obs_prior_mean, orig_obs_prior_var, net_a, stat_obs_prior, orig_hyb_mean, scalar1=obs_qc, &
scalar2=vertvalue_obs_in_localization_coord, scalar3=whichvert_real, &
scalar4=my_hybrid_weight, scalar5=my_hybrid_weight_sd, scalar6=hyb_obs_var)

Expand Down Expand Up @@ -1197,6 +1194,13 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
! Loop through to update each of my state variables that is potentially close
if (timing(LG_GRN)) call start_timer(t_base(LG_GRN))

!print *, ''
!write(*, '(A, I2)') 'obsnum: ', i
!write(*, '(A, 20F10.6)') 'x(879): ', ens_handle%copies(1:ens_size, 879)
!write(*, '(A, 20F10.6)') 'x(880): ', ens_handle%copies(1:ens_size, 880)
!write(*, '(A, 20F10.6)') 'x(881): ', ens_handle%copies(1:ens_size, 881)
!print *, ''

STATE_UPDATE: do j = 1, num_close_states
state_index = close_state_ind(j)

Expand Down Expand Up @@ -1246,16 +1250,16 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
obs_inc, ens_handle%copies(1:ens_size, state_index), ens_size, &
stat_obs_prior, stat_obs_prior_mean, stat_obs_prior_var, &
stat_ens_handle%copies(1:hybrid_ens_size, state_index), hybrid_ens_size, &
!ens_handle%copies(ENS_MEAN_COPY, state_index),
orig_hybrid_weight(state_index), increment, reg_coef(1), net_a(1), correl(1))
ens_handle%copies(ENS_MEAN_COPY, state_index), orig_hyb_mean(1), &
increment, reg_coef(1), net_a(1), correl(1))
else
call update_from_hybobs_inc(obs_prior, obs_prior_mean(1), obs_prior_var(1), &
obs_inc, ens_handle%copies(1:ens_size, state_index), ens_size, &
stat_obs_prior, stat_obs_prior_mean, stat_obs_prior_var, &
stat_ens_handle%copies(1:hybrid_ens_size, state_index), hybrid_ens_size, &
!ens_handle%copies(ENS_MEAN_COPY, state_index),
orig_hybrid_weight(state_index), increment, reg_coef(1), net_a(1))
endif
ens_handle%copies(ENS_MEAN_COPY, state_index), orig_hyb_mean(1), &
increment, reg_coef(1), net_a(1))
endif
else
! Loop through groups to update the state variable ensemble members
do group = 1, num_groups
Expand Down Expand Up @@ -1441,7 +1445,7 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
obs_ens_handle%copies(1:ens_size, obs_index), ens_size, &
stat_obs_prior, stat_obs_prior_mean, stat_obs_prior_var, &
stat_obs_ens_handle%copies(1:hybrid_ens_size, obs_index), hybrid_ens_size, &
orig_hyb_mean, increment, reg_coef(1), net_a(1))
orig_hyb_mean(1), orig_hyb_mean(1), increment, reg_coef(1), net_a(1))
else
do group = 1, num_groups
grp_bot = grp_beg(group)
Expand Down Expand Up @@ -1599,7 +1603,7 @@ subroutine filter_assim(ens_handle, obs_ens_handle, obs_seq, keys, ens_size,
deallocate(n_close_state_items, &
n_close_obs_items)

if (do_hybrid) deallocate(stat_obs_prior, orig_hybrid_weight)
if (do_hybrid) deallocate(stat_obs_prior)
! end dealloc

end subroutine filter_assim
Expand Down Expand Up @@ -2276,7 +2280,8 @@ end subroutine update_from_obs_inc

subroutine update_from_hybobs_inc(obs, obs_prior_mean, obs_prior_var, obs_inc, &
state, ens_size, clim_obs, clim_obs_prior_mean, clim_obs_prior_var, &
clim_state, ens_size2, weight_factor, state_inc, reg_coef, net_a, correl_out)
clim_state, ens_size2, alpha_x, alpha_y, state_inc, reg_coef, &
net_a, correl_out)
!========================================================================

! Does linear regression of a state variable onto an observation and
Expand All @@ -2291,7 +2296,7 @@ subroutine update_from_hybobs_inc(obs, obs_prior_mean, obs_prior_var, obs_inc,
real(r8), intent(in) :: obs_prior_mean, obs_prior_var
real(r8), intent(in) :: clim_obs_prior_mean, clim_obs_prior_var
real(r8), intent(in) :: state(ens_size), clim_state(ens_size2)
real(r8), intent(in) :: weight_factor
real(r8), intent(in) :: alpha_x, alpha_y
real(r8), intent(out) :: state_inc(ens_size), reg_coef
real(r8), intent(inout) :: net_a
real(r8), optional, intent(inout) :: correl_out
Expand All @@ -2300,6 +2305,7 @@ subroutine update_from_hybobs_inc(obs, obs_prior_mean, obs_prior_var, obs_inc,
real(r8) :: restoration_inc(ens_size), state_mean, state_var, correl
real(r8) :: factor, exp_true_correl, mean_factor
real(r8) :: clim_state_mean, clim_obs_state_cov
real(r8) :: hyb_state_var, clim_state_var

! For efficiency, just compute regression coefficient here unless correl is needed
state_mean = sum(state) / ens_size
Expand All @@ -2308,8 +2314,14 @@ subroutine update_from_hybobs_inc(obs, obs_prior_mean, obs_prior_var, obs_inc,
clim_state_mean = sum(clim_state) / ens_size2
clim_obs_state_cov = sum( (clim_state - clim_state_mean) * (clim_obs - clim_obs_prior_mean) ) / (ens_size2 - 1)

hyb_obs_prior_var = weight_factor * obs_prior_var + (1.0_r8 - weight_factor) * clim_obs_prior_var
hyb_obs_state_cov = weight_factor * obs_state_cov + (1.0_r8 - weight_factor) * clim_obs_state_cov
!if (clim_state_mean /= clim_state_mean .or. clim_obs_prior_mean /= clim_obs_prior_mean) then
! hyb_obs_prior_var = obs_prior_var
! hyb_obs_state_cov = obs_state_cov
!else
hyb_obs_prior_var = alpha_y * obs_prior_var + (1.0_r8 - alpha_y) * clim_obs_prior_var
hyb_obs_state_cov = sqrt(alpha_x) * sqrt(alpha_y) * obs_state_cov + &
sqrt(1.0_r8 - alpha_x) * sqrt(1.0_r8 - alpha_y) * clim_obs_state_cov
!endif

if (hyb_obs_prior_var > 0.0_r8) then
reg_coef = hyb_obs_state_cov / hyb_obs_prior_var
Expand Down Expand Up @@ -2338,7 +2350,9 @@ subroutine update_from_hybobs_inc(obs, obs_prior_mean, obs_prior_var, obs_inc,
if (obs_state_cov == 0.0_r8 .or. obs_prior_var <= 0.0_r8) then
correl = 0.0_r8
else
state_var = sum((state - state_mean)**2) / (ens_size - 1)
state_var = sum((state - state_mean)**2) / (ens_size - 1)
!clim_state_var = sum((clim_state - clim_state_mean)**2) / (ens_size2 - 1)
!hyb_state_var = weight_factor * state_var + (1.0_r8 - weight_factor) * clim_state_var
if (state_var <= 0.0_r8) then
correl = 0.0_r8
else
Expand Down
53 changes: 39 additions & 14 deletions assimilation_code/modules/assimilation/filter_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ module filter_mod
logical :: do_hybrid = .false.
logical :: output_hybrid = .false.
integer :: STAT_ENS_MEAN_COPY = COPY_NOT_PRESENT
integer :: STAT_ENS_SD_COPY = COPY_NOT_PRESENT
integer :: STAT_ENS_SD_COPY = COPY_NOT_PRESENT
integer :: STAT_HYB_COPY = COPY_NOT_PRESENT
type(hybrid_type) :: hybridization

logical :: has_cycling = .false. ! filter will advance the model
Expand Down Expand Up @@ -537,9 +538,6 @@ subroutine filter_main()
num_state_ens_copies = count_state_ens_copies(ens_size, prior_inflate, post_inflate, hybridization)
num_extras = num_state_ens_copies - ens_size

!print *, 'num_state_ens_copies: ', num_state_ens_copies
!print *, 'num_extras: ', num_extras

! Observation
OBS_ERR_VAR_COPY = ens_size + 1
OBS_VAL_COPY = ens_size + 2
Expand All @@ -559,9 +557,10 @@ subroutine filter_main()
endif

! state copies: # of static member files + mean + sd
num_static_ens_copies = hyb_ens_size + 2
STAT_ENS_MEAN_COPY = hyb_ens_size + 1
STAT_ENS_SD_COPY = hyb_ens_size + 2
num_static_ens_copies = hyb_ens_size + 3
STAT_HYB_COPY = hyb_ens_size + 1
STAT_ENS_MEAN_COPY = hyb_ens_size + 2
STAT_ENS_SD_COPY = hyb_ens_size + 3

! Indices (dummy) needed to compute forward operators only
STATIC_OBS_ERR_VAR_COPY = hyb_ens_size + 1
Expand Down Expand Up @@ -632,7 +631,9 @@ subroutine filter_main()

call set_num_extra_copies(state_ens_handle, num_extras)

! For the static members, we only have 2 extra copies: meand and sd
! For the static members, we only have 3 extra copies: meand, sd, hyb weight
! We are going to cheat and tell it that hybrid weight is part of the ensemble
! Like this we can get the obs_space weights for identity obs
if (do_hybrid) call set_num_extra_copies(static_state_ens_handle, 2)

call trace_message('After setting up space for ensembles')
Expand Down Expand Up @@ -690,12 +691,17 @@ subroutine filter_main()
call read_state(static_state_ens_handle, file_info_hybrid, read_time_from_file, time2)
endif

!print *, ''
!write(*, '(A)') 'After read state:'
!write(*, '(A, 20F10.6)') 'x(878): ', state_ens_handle%copies(1:ens_size, 878)
!write(*, '(A, 20F10.6)') 'x(879): ', state_ens_handle%copies(1:ens_size, 879)

!print *, '130 ens var: ', state_ens_handle%copies(1:ens_size, 130)
!print *, ''
!print *, '130 hyb var: ', static_state_ens_handle%copies(1:hyb_ens_size, 130)

!print *, 'extras ens: ', state_ens_handle%copies(ens_size+1:num_state_ens_copies, 130)
!print *, 'extras hyb: ', static_state_ens_handle%copies(hyb_ens_size+1:num_static_ens_copies, 130)
!print *, 'extras hyb 13644: ', static_state_ens_handle%copies(hyb_ens_size+1:num_static_ens_copies, 13644)
!print *, 'extras hyb 15914: ', static_state_ens_handle%copies(hyb_ens_size+1:num_static_ens_copies, 15914)

! This must be after read_state
call get_minmax_task_zero(prior_inflate, state_ens_handle, PRIOR_INF_COPY, PRIOR_INF_SD_COPY)
Expand Down Expand Up @@ -779,7 +785,13 @@ subroutine filter_main()
call compute_copy_mean_sd(state_ens_handle, 1, ens_size, ENS_MEAN_COPY, ENS_SD_COPY)

! Compute mean and spread for static ensemble
if(do_hybrid) call compute_copy_mean_sd(static_state_ens_handle, 1, hyb_ens_size, STAT_ENS_MEAN_COPY, STAT_ENS_SD_COPY)
if(do_hybrid) then
call compute_copy_mean_sd(static_state_ens_handle, 1, hyb_ens_size, STAT_ENS_MEAN_COPY, STAT_ENS_SD_COPY)
static_state_ens_handle%copies(STAT_HYB_COPY, :) = state_ens_handle%copies(HYBRID_WEIGHT_MEAN_COPY, :)
endif

!print *, 'extras hyb 13644: ', static_state_ens_handle%copies(hyb_ens_size+1:num_static_ens_copies, 13644)
!print *, 'extras hyb 15914: ', static_state_ens_handle%copies(hyb_ens_size+1:num_static_ens_copies, 15914)

! Write out the mean and sd for the input files if requested
if (get_stage_to_write('input')) then
Expand Down Expand Up @@ -1014,15 +1026,19 @@ subroutine filter_main()
static_qc_ens_handle, seq, keys, obs_val_index, input_qc_index, &
STATIC_OBS_ERR_VAR_COPY, STATIC_OBS_VAL_COPY, STATIC_OBS_KEY_COPY, STATIC_OBS_GLOBAL_QC_COPY, &
STATIC_OBS_EXTRA_QC_COPY, STATIC_OBS_MEAN_START, STATIC_OBS_VAR_START, &
isprior=.true., prior_qc_copy=static_prior_qc_copy)
isprior=.true., prior_qc_copy=static_prior_qc_copy, do_hybrid=do_hybrid)
endif

!print *, 'obs_ens: ', obs_fwd_op_ens_handle%copies(1:ens_size, 1)
!print *, 'rest: ' , obs_fwd_op_ens_handle%copies(ens_size+1:TOTAL_OBS_COPIES, 1)

!print *, 'obs_hyb: ', static_obs_ens_handle%copies(1:hyb_ens_size, 1)
!print *, 'rest: ' , static_obs_ens_handle%copies(hyb_ens_size+1:TOTAL_STATIC_OBS_COPIES, 1)


!print *, 'hyb_ens: ', static_obs_ens_handle%copies(1:hyb_ens_size, 1)

!print *, 'rest 1: ', static_obs_ens_handle%copies(hyb_ens_size+1:TOTAL_STATIC_OBS_COPIES, 1)
!print *, 'rest 2: ', static_obs_ens_handle%copies(hyb_ens_size+1:TOTAL_STATIC_OBS_COPIES, 2)

call timestamp_message('After computing prior observation values')
call trace_message('After computing prior observation values')

Expand Down Expand Up @@ -1250,6 +1266,15 @@ subroutine filter_main()
endif ! sd >= 0 or sd from restart file
endif ! if doing state space posterior inflate

!print *, ''
!write(*, '(A)') 'Before writing out the state:'
!write(*, '(A, 20F10.6)') 'x(878): ', state_ens_handle%copies(1:ens_size, 878)
!write(*, '(A, 20F10.6)') 'x(879): ', state_ens_handle%copies(1:ens_size, 879)

!print *, ''
!write(*, '(A, F10.6)') 'x878 mean: ', state_ens_handle%copies(ENS_MEAN_COPY, 878)
!write(*, '(A, F10.6)') 'x879 mean: ', state_ens_handle%copies(ENS_MEAN_COPY, 879)

! Write out analysis diagnostic files if requested. This contains the
! posterior inflated ensemble and updated {prior,posterior} inflation values
if (get_stage_to_write('analysis')) then
Expand Down
Loading

0 comments on commit 60fae50

Please sign in to comment.