Skip to content

Commit

Permalink
Merge remote-tracking branch 'MAIN/main' into feature/update-to-main-…
Browse files Browse the repository at this point in the history
…20240401
  • Loading branch information
jiandewang committed Apr 16, 2024
2 parents ab7bd14 + 87913b5 commit 6d0150d
Show file tree
Hide file tree
Showing 11 changed files with 808 additions and 304 deletions.
140 changes: 119 additions & 21 deletions config_src/drivers/nuopc_cap/mom_cap.F90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module MOM_cap_mod
use MOM_domains, only: MOM_infra_init, MOM_infra_end
use MOM_file_parser, only: get_param, log_version, param_file_type, close_param_file
use MOM_get_input, only: get_MOM_input, directories
use MOM_domains, only: pass_var
use MOM_domains, only: pass_var, pe_here
use MOM_error_handler, only: MOM_error, FATAL, is_root_pe
use MOM_grid, only: ocean_grid_type, get_global_grid_size
use MOM_ocean_model_nuopc, only: ice_ocean_boundary_type
Expand All @@ -29,6 +29,7 @@ module MOM_cap_mod
use MOM_cap_methods, only: med2mod_areacor, state_diagnose
use MOM_cap_methods, only: ChkErr
use MOM_ensemble_manager, only: ensemble_manager_init
use MOM_coms, only: sum_across_PEs

#ifdef CESMCOUPLED
use shr_log_mod, only: shr_log_setLogUnit
Expand Down Expand Up @@ -842,6 +843,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
type(ocean_grid_type) , pointer :: ocean_grid
type(ocean_internalstate_wrapper) :: ocean_internalstate
integer :: npet, ntiles
integer :: npes ! number of PEs (from FMS).
integer :: nxg, nyg, cnt
integer :: isc,iec,jsc,jec
integer, allocatable :: xb(:),xe(:),yb(:),ye(:),pe(:)
Expand All @@ -868,6 +870,8 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
integer :: lsize
integer :: ig,jg, ni,nj,k
integer, allocatable :: gindex(:) ! global index space
integer, allocatable :: gindex_ocn(:) ! global index space for ocean cells (excl. masked cells)
integer, allocatable :: gindex_elim(:) ! global index space for eliminated cells
character(len=128) :: fldname
character(len=256) :: cvalue
character(len=256) :: frmt ! format specifier for several error msgs
Expand All @@ -891,6 +895,11 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
real(ESMF_KIND_R8) :: min_areacor_glob(2)
real(ESMF_KIND_R8) :: max_areacor_glob(2)
character(len=*), parameter :: subname='(MOM_cap:InitializeRealize)'
integer :: niproc, njproc
integer :: ip, jp, pe_ix
integer :: num_elim_blocks ! number of blocks to be eliminated
integer :: num_elim_cells_global, num_elim_cells_local, num_elim_cells_remaining
integer, allocatable :: cell_mask(:,:)
real(8) :: MPI_Wtime, timeirls
!--------------------------------

Expand Down Expand Up @@ -937,19 +946,19 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
rc = ESMF_FAILURE
call ESMF_LogWrite(subname//' ntiles must be 1', ESMF_LOGMSG_ERROR)
endif
ntiles = mpp_get_domain_npes(ocean_public%domain)
write(tmpstr,'(a,1i6)') subname//' ntiles = ',ntiles
npes = mpp_get_domain_npes(ocean_public%domain)
write(tmpstr,'(a,1i6)') subname//' npes = ',npes
call ESMF_LogWrite(trim(tmpstr), ESMF_LOGMSG_INFO)

!---------------------------------
! get start and end indices of each tile and their PET
!---------------------------------

allocate(xb(ntiles),xe(ntiles),yb(ntiles),ye(ntiles),pe(ntiles))
allocate(xb(npes),xe(npes),yb(npes),ye(npes),pe(npes))
call mpp_get_compute_domains(ocean_public%domain, xbegin=xb, xend=xe, ybegin=yb, yend=ye)
call mpp_get_pelist(ocean_public%domain, pe)
if (dbug > 1) then
do n = 1,ntiles
do n = 1,npes
write(tmpstr,'(a,6i6)') subname//' tiles ',n,pe(n),xb(n),xe(n),yb(n),ye(n)
call ESMF_LogWrite(trim(tmpstr), ESMF_LOGMSG_INFO)
enddo
Expand All @@ -971,17 +980,102 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
call get_global_grid_size(ocean_grid, ni, nj)
lsize = ( ocean_grid%iec - ocean_grid%isc + 1 ) * ( ocean_grid%jec - ocean_grid%jsc + 1 )

! Create the global index space for the computational domain
allocate(gindex(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex(k) = ni * (jg - 1) + ig
num_elim_blocks = 0
num_elim_cells_global = 0
num_elim_cells_local = 0
num_elim_cells_remaining = 0

! Compute the number of eliminated blocks (specified in MOM_mask_table)
if (associated(ocean_grid%Domain%maskmap)) then
njproc = size(ocean_grid%Domain%maskmap, 1)
niproc = size(ocean_grid%Domain%maskmap, 2)

do ip = 1, niproc
do jp = 1, njproc
if (.not. ocean_grid%Domain%maskmap(jp,ip)) then
num_elim_blocks = num_elim_blocks+1
endif
enddo
enddo
enddo
endif

! Apply land block elimination to ESMF gindex
! (Here we assume that each processor gets assigned a single tile. If multi-tile implementation is to be added
! in MOM6 NUOPC cap in the future, below code must be updated accordingly.)
if (num_elim_blocks>0) then

allocate(cell_mask(ni, nj), source=0)
allocate(gindex_ocn(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex_ocn(k) = ni * (jg - 1) + ig
cell_mask(ig, jg) = 1
enddo
enddo
call sum_across_PEs(cell_mask, ni*nj)

if (maxval(cell_mask) /= 1 ) then
call MOM_error(FATAL, "Encountered cells shared by multiple PEs while attempting to determine masked cells.")
endif

num_elim_cells_global = ni * nj - sum(cell_mask)
num_elim_cells_local = num_elim_cells_global / npes

if (pe_here() == pe(npes)) then
! assign all remaining cells to the last PE.
num_elim_cells_remaining = num_elim_cells_global - num_elim_cells_local * npes
allocate(gindex_elim(num_elim_cells_local+num_elim_cells_remaining))
else
allocate(gindex_elim(num_elim_cells_local))
endif

! Zero-based PE index.
pe_ix = pe_here() - pe(1)

k = 0
do jg = 1, nj
do ig = 1, ni
if (cell_mask(ig, jg) == 0) then
k = k + 1
if (k > pe_ix * num_elim_cells_local .and. &
k <= ((pe_ix+1) * num_elim_cells_local + num_elim_cells_remaining)) then
gindex_elim(k - pe_ix * num_elim_cells_local) = ni * (jg -1) + ig
endif
endif
enddo
enddo

allocate(gindex(lsize + num_elim_cells_local + num_elim_cells_remaining))
do k = 1, lsize
gindex(k) = gindex_ocn(k)
enddo
do k = 1, num_elim_cells_local + num_elim_cells_remaining
gindex(k+lsize) = gindex_elim(k)
enddo

deallocate(cell_mask)
deallocate(gindex_ocn)
deallocate(gindex_elim)

else ! no eliminated land blocks

! Create the global index space for the computational domain
allocate(gindex(lsize))
k = 0
do j = ocean_grid%jsc, ocean_grid%jec
jg = j + ocean_grid%jdg_offset
do i = ocean_grid%isc, ocean_grid%iec
ig = i + ocean_grid%idg_offset
k = k + 1 ! Increment position within gindex
gindex(k) = ni * (jg - 1) + ig
enddo
enddo

endif

DistGrid = ESMF_DistGridCreate(arbSeqIndexList=gindex, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return
Expand All @@ -1005,6 +1099,10 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
call ESMF_MeshGet(Emesh, spatialDim=spatialDim, numOwnedElements=numOwnedElements, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return

if (lsize /= numOwnedElements - num_elim_cells_local - num_elim_cells_remaining) then
call MOM_error(FATAL, "Discrepancy detected between ESMF mesh and internal MOM6 domain sizes. Check mask table.")
endif

allocate(ownedElemCoords(spatialDim*numOwnedElements))
allocate(lonMesh(numOwnedElements), lon(numOwnedElements))
allocate(latMesh(numOwnedElements), lat(numOwnedElements))
Expand Down Expand Up @@ -1036,7 +1134,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
end do

eps_omesh = get_eps_omesh(ocean_state)
do n = 1,numOwnedElements
do n = 1,lsize
diff_lon = abs(mod(lonMesh(n) - lon(n),360.0))
if (diff_lon > eps_omesh) then
frmt = "('ERROR: Difference between ESMF Mesh and MOM6 domain coords is "//&
Expand Down Expand Up @@ -1140,11 +1238,11 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)

! generate delayout and dist_grid

allocate(deBlockList(2,2,ntiles))
allocate(petMap(ntiles))
allocate(deLabelList(ntiles))
allocate(deBlockList(2,2,npes))
allocate(petMap(npes))
allocate(deLabelList(npes))

do n = 1, ntiles
do n = 1, npes
deLabelList(n) = n
deBlockList(1,1,n) = xb(n)
deBlockList(1,2,n) = xe(n)
Expand Down Expand Up @@ -1727,7 +1825,7 @@ subroutine ModelAdvance(gcomp, rc)
rpointer_filename = 'rpointer.ocn'//trim(inst_suffix)

write(restartname,'(A,".mom6.r.",I4.4,"-",I2.2,"-",I2.2,"-",I5.5)') &
trim(casename), year, month, day, seconds
trim(casename), year, month, day, hour * 3600 + minute * 60 + seconds
call ESMF_LogWrite("MOM_cap: Writing restart : "//trim(restartname), ESMF_LOGMSG_INFO)
! write restart file(s)
call ocean_model_restart(ocean_state, restartname=restartname, num_rest_files=num_rest_files)
Expand Down
9 changes: 8 additions & 1 deletion config_src/drivers/nuopc_cap/mom_cap_methods.F90
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ subroutine State_SetExport(state, fldname, isc, iec, jsc, jec, input, ocean_grid

! local variables
type(ESMF_StateItem_Flag) :: itemFlag
integer :: n, i, j, i1, j1, ig,jg
integer :: n, i, j, k, i1, j1, ig,jg
integer :: lbnd1,lbnd2
real(ESMF_KIND_R8), pointer :: dataPtr1d(:)
real(ESMF_KIND_R8), pointer :: dataPtr2d(:,:)
Expand Down Expand Up @@ -889,6 +889,13 @@ subroutine State_SetExport(state, fldname, isc, iec, jsc, jec, input, ocean_grid
enddo
end if

! if a maskmap is provided, set exports of all eliminated cells to zero.
if (associated(ocean_grid%Domain%maskmap)) then
do k = n+1, size(dataPtr1d)
dataPtr1d(k) = 0.0
enddo
endif

else if (geomtype == ESMF_GEOMTYPE_GRID) then

call state_getfldptr(state, trim(fldname), dataptr2d, rc)
Expand Down
15 changes: 13 additions & 2 deletions config_src/infra/FMS1/MOM_domain_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module MOM_domain_infra
use mpp_domains_mod, only : mpp_create_group_update, mpp_do_group_update
use mpp_domains_mod, only : mpp_reset_group_update_field, mpp_group_update_initialized
use mpp_domains_mod, only : mpp_start_group_update, mpp_complete_group_update
use mpp_domains_mod, only : mpp_compute_block_extent
use mpp_domains_mod, only : mpp_compute_block_extent, mpp_compute_extent
use mpp_domains_mod, only : mpp_broadcast_domain, mpp_redistribute, mpp_global_field
use mpp_domains_mod, only : AGRID, BGRID_NE, CGRID_NE, SCALAR_PAIR, BITWISE_EXACT_SUM
use mpp_domains_mod, only : CYCLIC_GLOBAL_DOMAIN, FOLD_NORTH_EDGE
Expand All @@ -40,7 +40,7 @@ module MOM_domain_infra
public :: domain2D, domain1D, group_pass_type
! These interfaces are actually implemented or have explicit interfaces in this file.
public :: create_MOM_domain, clone_MOM_domain, get_domain_components, get_domain_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent, compute_extent
public :: pass_var, pass_vector, fill_symmetric_edges, rescale_comp_data
public :: pass_var_start, pass_var_complete, pass_vector_start, pass_vector_complete
public :: create_group_pass, do_group_pass, start_group_pass, complete_group_pass
Expand Down Expand Up @@ -1945,6 +1945,17 @@ subroutine compute_block_extent(isg, ieg, ndivs, ibegin, iend)
call mpp_compute_block_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_block_extent

!> Get the array ranges in one dimension for the divisions of a global index space
subroutine compute_extent(isg, ieg, ndivs, ibegin, iend)
integer, intent(in) :: isg !< The starting index of the global index space
integer, intent(in) :: ieg !< The ending index of the global index space
integer, intent(in) :: ndivs !< The number of divisions
integer, dimension(:), intent(out) :: ibegin !< The starting index of each division
integer, dimension(:), intent(out) :: iend !< The ending index of each division

call mpp_compute_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_extent

!> Broadcast a 2-d domain from the root PE to the other PEs
subroutine broadcast_domain(domain)
type(domain2d), intent(inout) :: domain !< The domain2d type that will be shared across PEs.
Expand Down
10 changes: 10 additions & 0 deletions config_src/infra/FMS2/MOM_coms_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module MOM_coms_infra
interface sum_across_PEs
module procedure sum_across_PEs_int4_0d
module procedure sum_across_PEs_int4_1d
module procedure sum_across_PEs_int4_2d
module procedure sum_across_PEs_int8_0d
module procedure sum_across_PEs_int8_1d
module procedure sum_across_PEs_int8_2d
Expand Down Expand Up @@ -357,6 +358,15 @@ subroutine sum_across_PEs_int4_1d(field, length, pelist)
call mpp_sum(field, length, pelist)
end subroutine sum_across_PEs_int4_1d

!> Find the sum of the values in corresponding positions of field across PEs, and return these sums in field.
subroutine sum_across_PEs_int4_2d(field, length, pelist)
integer(kind=int32), dimension(:,:), intent(inout) :: field !< The values to add, the sums upon return
integer, intent(in) :: length !< Number of elements in field to add
integer, optional, intent(in) :: pelist(:) !< List of PEs to work with

call mpp_sum(field, length, pelist)
end subroutine sum_across_PEs_int4_2d

!> Find the sum of field across PEs, and return this sum in field.
subroutine sum_across_PEs_int8_0d(field, pelist)
integer(kind=int64), intent(inout) :: field !< Value on this PE, and the sum across PEs upon return
Expand Down
17 changes: 14 additions & 3 deletions config_src/infra/FMS2/MOM_domain_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module MOM_domain_infra
use mpp_domains_mod, only : mpp_create_group_update, mpp_do_group_update
use mpp_domains_mod, only : mpp_reset_group_update_field, mpp_group_update_initialized
use mpp_domains_mod, only : mpp_start_group_update, mpp_complete_group_update
use mpp_domains_mod, only : mpp_compute_block_extent
use mpp_domains_mod, only : mpp_compute_block_extent, mpp_compute_extent
use mpp_domains_mod, only : mpp_broadcast_domain, mpp_redistribute, mpp_global_field
use mpp_domains_mod, only : AGRID, BGRID_NE, CGRID_NE, SCALAR_PAIR, BITWISE_EXACT_SUM
use mpp_domains_mod, only : CYCLIC_GLOBAL_DOMAIN, FOLD_NORTH_EDGE
Expand All @@ -38,7 +38,7 @@ module MOM_domain_infra
public :: domain2D, domain1D, group_pass_type
! These interfaces are actually implemented or have explicit interfaces in this file.
public :: create_MOM_domain, clone_MOM_domain, get_domain_components, get_domain_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent
public :: deallocate_MOM_domain, get_global_shape, compute_block_extent, compute_extent
public :: pass_var, pass_vector, fill_symmetric_edges, rescale_comp_data
public :: pass_var_start, pass_var_complete, pass_vector_start, pass_vector_complete
public :: create_group_pass, do_group_pass, start_group_pass, complete_group_pass
Expand Down Expand Up @@ -1936,7 +1936,7 @@ subroutine get_global_shape(domain, niglobal, njglobal)
njglobal = domain%njglobal
end subroutine get_global_shape

!> Get the array ranges in one dimension for the divisions of a global index space
!> Get the array ranges in one dimension for the divisions of a global index space (alternative to compute_extent)
subroutine compute_block_extent(isg, ieg, ndivs, ibegin, iend)
integer, intent(in) :: isg !< The starting index of the global index space
integer, intent(in) :: ieg !< The ending index of the global index space
Expand All @@ -1947,6 +1947,17 @@ subroutine compute_block_extent(isg, ieg, ndivs, ibegin, iend)
call mpp_compute_block_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_block_extent

!> Get the array ranges in one dimension for the divisions of a global index space
subroutine compute_extent(isg, ieg, ndivs, ibegin, iend)
integer, intent(in) :: isg !< The starting index of the global index space
integer, intent(in) :: ieg !< The ending index of the global index space
integer, intent(in) :: ndivs !< The number of divisions
integer, dimension(:), intent(out) :: ibegin !< The starting index of each division
integer, dimension(:), intent(out) :: iend !< The ending index of each division

call mpp_compute_extent(isg, ieg, ndivs, ibegin, iend)
end subroutine compute_extent

!> Broadcast a 2-d domain from the root PE to the other PEs
subroutine broadcast_domain(domain)
type(domain2d), intent(inout) :: domain !< The domain2d type that will be shared across PEs.
Expand Down
8 changes: 4 additions & 4 deletions src/core/MOM.F90
Original file line number Diff line number Diff line change
Expand Up @@ -2539,12 +2539,12 @@ subroutine initialize_MOM(Time, Time_init, param_file, dirs, CS, &
G_in => CS%G_in
#ifdef STATIC_MEMORY_
call MOM_domains_init(G_in%domain, param_file, symmetric=symmetric, &
static_memory=.true., NIHALO=NIHALO_, NJHALO=NJHALO_, &
NIGLOBAL=NIGLOBAL_, NJGLOBAL=NJGLOBAL_, NIPROC=NIPROC_, &
NJPROC=NJPROC_)
static_memory=.true., NIHALO=NIHALO_, NJHALO=NJHALO_, &
NIGLOBAL=NIGLOBAL_, NJGLOBAL=NJGLOBAL_, NIPROC=NIPROC_, &
NJPROC=NJPROC_, US=US)
#else
call MOM_domains_init(G_in%domain, param_file, symmetric=symmetric, &
domain_name="MOM_in")
domain_name="MOM_in", US=US)
#endif

! Copy input grid (G_in) domain to active grid G
Expand Down
Loading

0 comments on commit 6d0150d

Please sign in to comment.